Merge branch 'main' into 4-27-app-deploy

This commit is contained in:
Stephen Zhou 2026-05-07 12:36:06 +08:00
commit b1773ed11f
No known key found for this signature in database
222 changed files with 6111 additions and 3570 deletions

View File

@ -32,12 +32,7 @@ class TagBindingPayload(BaseModel):
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
@ -75,7 +70,6 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -184,13 +178,13 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
@ -211,54 +205,15 @@ class TagBindingCollectionApi(Resource):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
class TagBindingRemoveApi(Resource):
"""Batch resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.doc("remove_tag_bindings")
@console_ns.doc(description="Remove one or more tag bindings from a target.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_binding()
return _remove_tag_bindings()

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
class TagUnbindingPayload(BaseModel):
tag_id: str
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
tag_ids: list[str] = Field(default_factory=list)
tag_id: str | None = None
target_id: str
@model_validator(mode="before")
@classmethod
def normalize_legacy_tag_id(cls, data: object) -> object:
if not isinstance(data, dict):
return data
if not data.get("tag_ids") and data.get("tag_id"):
return {**data, "tag_ids": [data["tag_id"]]}
return data
@model_validator(mode="after")
def validate_tag_ids(self) -> "TagUnbindingPayload":
if not self.tag_ids:
raise ValueError("Tag IDs is required.")
return self
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc("unbind_dataset_tags")
@service_api_ns.doc(description="Unbind tags from a dataset")
@service_api_ns.doc(
responses={
204: "Tag unbound successfully",
204: "Tags unbound successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
}
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204

View File

@ -1,7 +1,9 @@
from flask import Flask
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
def init_app(app: Flask):
with app.app_context():
configure_session_factory(db.engine)

View File

@ -246,8 +246,18 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
if not cluster_info.qdrant_endpoint:
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
item
) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
if cluster_info.qdrant_endpoint:
cluster_info.status = TidbAuthBindingStatus.ACTIVE
else:
logger.warning(
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
item["clusterId"],
)
db.session.add(cluster_info)
db.session.commit()
else:

View File

@ -1,8 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
from models.enums import TidbAuthBindingStatus
class TestExtractQdrantEndpoint:
"""Unit tests for TidbService.extract_qdrant_endpoint."""
@ -216,3 +219,86 @@ class TestBatchCreateEdgeCases:
private_key="priv",
region="us-east-1",
)
class TestBatchUpdateTidbServerlessClusterStatus:
    """Verify that status updates only expose clusters after qdrant endpoint is ready.

    Each test drives TidbService.batch_update_tidb_serverless_cluster_status with a
    single CREATING binding and a mocked TiDB Cloud batch-list response, then checks
    how the binding's account / qdrant_endpoint / status are mutated and persisted.
    Note: @patch decorators are applied bottom-up, so the innermost decorator
    (_tidb_http_client) maps to the first mock parameter after self.
    """

    @patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
        """Endpoint present in the batch payload -> binding flips to ACTIVE and is committed."""
        # Binding starts life in CREATING with no endpoint recorded.
        binding = SimpleNamespace(
            cluster_id="c-1",
            status=TidbAuthBindingStatus.CREATING,
            account="root",
            qdrant_endpoint=None,
        )
        # Batch response reports the cluster ACTIVE with a public endpoint host.
        mock_http.get.return_value = MagicMock(
            status_code=200,
            json=lambda: {
                "clusters": [
                    {
                        "clusterId": "c-1",
                        "state": "ACTIVE",
                        "userPrefix": "pfx",
                        "endpoints": {"public": {"host": "gw.tidbcloud.com"}},
                    }
                ]
            },
        )
        TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
        # userPrefix is folded into the account; the qdrant endpoint is derived
        # from the batch payload (host prefixed with "qdrant-"), so the binding
        # becomes ACTIVE and is added + committed exactly once.
        assert binding.account == "pfx.root"
        assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
        assert binding.status == TidbAuthBindingStatus.ACTIVE
        mock_db.session.add.assert_called_once_with(binding)
        mock_db.session.commit.assert_called_once()

    @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
        """Empty endpoints in the batch payload -> falls back to fetch_qdrant_endpoint, then ACTIVE."""
        binding = SimpleNamespace(
            cluster_id="c-1",
            status=TidbAuthBindingStatus.CREATING,
            account="root",
            qdrant_endpoint=None,
        )
        # "endpoints": {} means the batch response carries no usable host.
        mock_http.get.return_value = MagicMock(
            status_code=200,
            json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
        )
        TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
        assert binding.account == "pfx.root"
        assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
        assert binding.status == TidbAuthBindingStatus.ACTIVE
        # Fallback fetch must be hit with the per-cluster credentials/id.
        mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
        mock_db.session.add.assert_called_once_with(binding)
        mock_db.session.commit.assert_called_once()

    @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
    @patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
    @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
    def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
        """No endpoint anywhere -> binding stays CREATING (not exposed) but is still persisted."""
        binding = SimpleNamespace(
            cluster_id="c-1",
            status=TidbAuthBindingStatus.CREATING,
            account="root",
            qdrant_endpoint=None,
        )
        mock_http.get.return_value = MagicMock(
            status_code=200,
            json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
        )
        TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
        # Account is still updated, but with no endpoint the status must NOT
        # advance to ACTIVE; add/commit still happen so the account change sticks.
        assert binding.account == "pfx.root"
        assert binding.qdrant_endpoint is None
        assert binding.status == TidbAuthBindingStatus.CREATING
        mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
        mock_db.session.add.assert_called_once_with(binding)
        mock_db.session.commit.assert_called_once()

View File

@ -174,7 +174,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.62.0",
"pyrefly>=0.64.0",
"xinference-client>=2.7.0",
]

View File

@ -1,9 +1,11 @@
import uuid
from typing import cast
import sqlalchemy as sa
from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -29,7 +31,7 @@ class TagBindingCreatePayload(BaseModel):
class TagBindingDeletePayload(BaseModel):
tag_id: str
tag_ids: list[str] = Field(min_length=1)
target_id: str
type: TagType
@ -178,13 +180,18 @@ class TagService:
@staticmethod
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
.limit(1)
result = cast(
CursorResult,
db.session.execute(
delete(TagBinding).where(
TagBinding.target_id == payload.target_id,
TagBinding.tag_id.in_(payload.tag_ids),
TagBinding.tenant_id == current_user.current_tenant_id,
)
),
)
if tag_binding:
db.session.delete(tag_binding)
if result.rowcount:
db.session.commit()
@staticmethod

View File

@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask:
@pytest.fixture
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]:
"""
Request context fixture for containerized Flask application.
@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None,
@pytest.fixture
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]:
"""
Test client fixture for containerized Flask application.
@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli
@pytest.fixture
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]:
"""
Database session fixture for containerized testing.

View File

@ -7,6 +7,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
@ -69,7 +70,7 @@ def _unwrap(func):
class TestCompletionEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_completion_create_payload(self):
@ -86,7 +87,7 @@ class TestCompletionEndpoints:
)
assert payload.query == "hi"
def test_completion_api_success(self, app, monkeypatch):
def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -116,7 +117,7 @@ class TestCompletionEndpoints:
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -142,7 +143,7 @@ class TestCompletionEndpoints:
with pytest.raises(NotFound):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -166,7 +167,7 @@ class TestCompletionEndpoints:
with pytest.raises(completion_module.ProviderNotInitializeError):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app, monkeypatch):
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -193,10 +194,10 @@ class TestCompletionEndpoints:
class TestAppEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppApi()
method = _unwrap(api.put)
payload = {
@ -234,7 +235,7 @@ class TestAppEndpoints:
}
)
def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppIconApi()
method = _unwrap(api.post)
payload = {
@ -266,7 +267,7 @@ class TestAppEndpoints:
class TestOpsTraceEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_ops_trace_query_basic(self):
@ -277,7 +278,7 @@ class TestOpsTraceEndpoints:
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
assert payload.tracing_config["api_key"] == "k"
def test_trace_app_config_get_empty(self, app, monkeypatch):
def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
@ -292,7 +293,7 @@ class TestOpsTraceEndpoints:
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app, monkeypatch):
def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
@ -309,7 +310,7 @@ class TestOpsTraceEndpoints:
with pytest.raises(BadRequest):
method(app_id="app-1")
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
@ -326,7 +327,7 @@ class TestOpsTraceEndpoints:
class TestSiteEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_site_response_structure(self):
@ -337,7 +338,7 @@ class TestSiteEndpoints:
payload = AppSiteUpdatePayload(default_language="en-US")
assert payload.default_language == "en-US"
def test_app_site_update_post(self, app, monkeypatch):
def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSite()
method = _unwrap(api.post)
@ -375,7 +376,7 @@ class TestSiteEndpoints:
assert isinstance(result, dict)
assert result["title"] == "My Site"
def test_app_site_access_token_reset(self, app, monkeypatch):
def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
@ -427,7 +428,7 @@ class TestWorkflowEndpoints:
class TestWorkflowAppLogEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_app_log_query(self):
@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints:
query = WorkflowAppLogQuery(detail="true")
assert query.detail is True
def test_workflow_app_log_api_get(self, app, monkeypatch):
def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints:
class TestWorkflowDraftVariableEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_variable_creation(self):
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
def test_workflow_variable_collection_get(self, app, monkeypatch):
def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints:
class TestWorkflowStatisticEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_statistic_time_range(self):
@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints:
assert query.start is None
assert query.end is None
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints:
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints:
class TestWorkflowTriggerEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_webhook_trigger_payload(self):
@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints:
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
assert enable_payload.enable_trigger is True
def test_webhook_trigger_api_get(self, app, monkeypatch):
def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
class TestAppImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -57,7 +58,7 @@ class TestAppImportApi:
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -75,7 +76,7 @@ class TestAppImportApi:
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -96,7 +97,7 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -121,7 +122,7 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -149,10 +150,10 @@ class TestAppImportApi:
class TestAppImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
@ -172,10 +173,10 @@ class TestAppImportConfirmApi:
class TestAppImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
@ -16,7 +17,7 @@ from services.account_service import AccountService
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
return flask_app_with_containers
@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi:
mock_is_freeze,
mock_send_mail,
mock_get_account,
app,
app: Flask,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi:
mock_revoke,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
@ -120,7 +121,7 @@ class TestEmailRegisterResetApi:
mock_create_account,
mock_login,
mock_reset_login_rate,
app,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
@ -16,7 +17,7 @@ from services.account_service import AccountService
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
return flask_app_with_containers
@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi:
mock_is_ip_limit,
mock_send_email,
mock_get_account,
app,
app: Flask,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
@ -123,7 +124,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_get_account,
mock_update_account,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.mark.parametrize(
@ -65,7 +66,7 @@ class TestOAuthLogin:
return OAuthLogin()
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -130,7 +131,7 @@ class TestOAuthCallback:
return OAuthCallback()
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -394,7 +395,7 @@ class TestOAuthCallback:
class TestAccountGeneration:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email.assert_called_once()
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask):
"""
Test password reset email blocked by IP rate limit.
@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi:
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask):
"""
Test code verification blocked by rate limit.
@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with invalid token.
@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with mismatched email.
@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with incorrect code.
@ -321,7 +322,7 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -375,7 +376,7 @@ class TestForgotPasswordResetApi:
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, app):
def test_reset_password_mismatch(self, mock_get_data, app: Flask):
"""
Test password reset with mismatched passwords.
@ -397,7 +398,7 @@ class TestForgotPasswordResetApi:
api.post()
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, app):
def test_reset_password_invalid_token(self, mock_get_data, app: Flask):
"""
Test password reset with invalid token.
@ -418,7 +419,7 @@ class TestForgotPasswordResetApi:
api.post()
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, app):
def test_reset_password_wrong_phase(self, mock_get_data, app: Flask):
"""
Test password reset with token not in reset phase.
@ -442,7 +443,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask):
"""
Test password reset for non-existent account.

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -26,7 +27,7 @@ def unwrap(func):
class TestPipelineTemplateListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
@ -50,7 +51,7 @@ class TestPipelineTemplateListApi:
class TestPipelineTemplateDetailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
@ -115,7 +116,7 @@ class TestPipelineTemplateDetailApi:
class TestCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
@ -193,7 +194,7 @@ class TestCustomizedPipelineTemplateApi:
class TestPublishCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
import services
@ -24,13 +25,13 @@ def unwrap(func):
class TestCreateRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _valid_payload(self):
return {"yaml_content": "name: test"}
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi:
assert status == 201
assert response == import_info
def test_post_forbidden_non_editor(self, app):
def test_post_forbidden_non_editor(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi:
with pytest.raises(Forbidden):
method(api)
def test_post_dataset_name_duplicate(self, app):
def test_post_dataset_name_duplicate(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi:
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload(self, app):
def test_post_invalid_payload(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi:
class TestCreateEmptyRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi:
assert status == 201
assert response == {"id": "ds-1"}
def test_post_forbidden_non_editor(self, app):
def test_post_forbidden_non_editor(self, app: Flask):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
@ -25,7 +26,7 @@ def unwrap(func):
class TestRagPipelineImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _payload(self, mode="create"):
@ -128,7 +129,7 @@ class TestRagPipelineImportApi:
class TestRagPipelineImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_confirm_success(self, app):
@ -190,7 +191,7 @@ class TestRagPipelineImportConfirmApi:
class TestRagPipelineImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
@ -219,7 +220,7 @@ class TestRagPipelineImportCheckDependenciesApi:
class TestRagPipelineExportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_with_include_secret(self, app):

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
@ -45,10 +46,10 @@ def unwrap(func):
class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_draft_success(self, app):
def test_get_draft_success(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@ -68,7 +69,7 @@ class TestDraftWorkflowApi:
result = method(api, pipeline)
assert result == workflow
def test_get_draft_not_exist(self, app):
def test_get_draft_not_exist(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@ -86,7 +87,7 @@ class TestDraftWorkflowApi:
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_sync_hash_not_match(self, app):
def test_sync_hash_not_match(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.post)
@ -111,7 +112,7 @@ class TestDraftWorkflowApi:
with pytest.raises(DraftWorkflowNotSync):
method(api, pipeline)
def test_sync_invalid_text_plain(self, app):
def test_sync_invalid_text_plain(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.post)
@ -128,7 +129,7 @@ class TestDraftWorkflowApi:
response, status = method(api, pipeline)
assert status == 400
def test_restore_published_workflow_to_draft_success(self, app):
def test_restore_published_workflow_to_draft_success(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -155,7 +156,7 @@ class TestDraftWorkflowApi:
assert result["result"] == "success"
assert result["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(self, app):
def test_restore_published_workflow_to_draft_not_found(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -179,7 +180,7 @@ class TestDraftWorkflowApi:
with pytest.raises(NotFound):
method(api, pipeline, "published-workflow")
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app):
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -211,10 +212,10 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_iteration_node_success(self, app):
def test_iteration_node_success(self, app: Flask):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@ -240,7 +241,7 @@ class TestDraftRunNodes:
result = method(api, pipeline, "node")
assert result == {"ok": True}
def test_iteration_node_conversation_not_exists(self, app):
def test_iteration_node_conversation_not_exists(self, app: Flask):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@ -262,7 +263,7 @@ class TestDraftRunNodes:
with pytest.raises(NotFound):
method(api, pipeline, "node")
def test_loop_node_success(self, app):
def test_loop_node_success(self, app: Flask):
api = RagPipelineDraftRunLoopNodeApi()
method = unwrap(api.post)
@ -290,10 +291,10 @@ class TestDraftRunNodes:
class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_draft_run_success(self, app):
def test_draft_run_success(self, app: Flask):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@ -325,7 +326,7 @@ class TestPipelineRunApis:
):
assert method(api, pipeline) == {"ok": True}
def test_draft_run_rate_limit(self, app):
def test_draft_run_rate_limit(self, app: Flask):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@ -356,10 +357,10 @@ class TestPipelineRunApis:
class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_execution_not_found(self, app):
def test_execution_not_found(self, app: Flask):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
@ -387,7 +388,7 @@ class TestDraftNodeRun:
class TestPublishedPipelineApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
@ -436,10 +437,10 @@ class TestPublishedPipelineApis:
class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_task_stop(self, app):
def test_task_stop(self, app: Flask):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
@ -460,7 +461,7 @@ class TestMiscApis:
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app):
def test_transform_forbidden(self, app: Flask):
api = RagPipelineTransformApi()
method = unwrap(api.post)
@ -476,7 +477,7 @@ class TestMiscApis:
with pytest.raises(Forbidden):
method(api, "ds1")
def test_recommended_plugins(self, app):
def test_recommended_plugins(self, app: Flask):
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
@ -496,10 +497,10 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_published_run_success(self, app):
def test_published_run_success(self, app: Flask):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi:
result = method(api, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app):
def test_published_run_rate_limit(self, app: Flask):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi:
class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_block_config_success(self, app):
def test_get_block_config_success(self, app: Flask):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi:
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app):
def test_get_block_config_invalid_json(self, app: Flask):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_published_workflows_success(self, app):
def test_get_published_workflows_success(self, app: Flask):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi:
assert result["items"] == [{"id": "w1"}]
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app):
def test_get_published_workflows_forbidden(self, app: Flask):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi:
class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@ -682,7 +683,7 @@ class TestRagPipelineByIdApi:
assert result == workflow
def test_patch_no_fields(self, app):
def test_patch_no_fields(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@ -700,7 +701,7 @@ class TestRagPipelineByIdApi:
result, status = method(api, pipeline, "w1")
assert status == 400
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.delete)
@ -720,7 +721,7 @@ class TestRagPipelineByIdApi:
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app):
def test_delete_active_workflow_rejected(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.delete)
@ -733,10 +734,10 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_last_run_success(self, app):
def test_last_run_success(self, app: Flask):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi:
result = method(api, pipeline, "node1")
assert result == node_exec
def test_last_run_not_found(self, app):
def test_last_run_not_found(self, app: Flask):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_set_datasource_variables_success(self, app):
def test_set_datasource_variables_success(self, app: Flask):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console.datasets import data_source
@ -51,7 +52,7 @@ def mock_engine():
class TestDataSourceApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
@ -188,7 +189,7 @@ class TestDataSourceApi:
class TestDataSourceNotionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
@ -323,7 +324,7 @@ class TestDataSourceNotionListApi:
class TestDataSourceNotionApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
@ -381,7 +382,7 @@ class TestDataSourceNotionApi:
class TestDataSourceNotionDatasetSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
@ -424,7 +425,7 @@ class TestDataSourceNotionDatasetSyncApi:
class TestDataSourceNotionDocumentSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
@ -53,7 +54,7 @@ def user():
class TestConversationListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
@ -108,7 +109,7 @@ class TestConversationListApi:
class TestConversationApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
@ -156,7 +157,7 @@ class TestConversationApi:
class TestConversationRenameApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
@ -197,7 +198,7 @@ class TestConversationRenameApi:
class TestConversationPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
@ -219,7 +220,7 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):

View File

@ -6,6 +6,7 @@ import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import (
@ -60,7 +61,7 @@ def _mock_user_tenant():
@pytest.fixture
def client(flask_app_with_containers):
def client(flask_app_with_containers: Flask):
return flask_app_with_containers.test_client()
@ -147,10 +148,10 @@ class TestUtils:
class TestToolProviderListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ToolProviderListApi()
method = unwrap(api.get)
@ -170,10 +171,10 @@ class TestToolProviderListApi:
class TestBuiltinProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_tools(self, app):
def test_list_tools(self, app: Flask):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
@ -190,7 +191,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == [{"a": 1}]
def test_info(self, app):
def test_info(self, app: Flask):
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
@ -207,7 +208,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"x": 1}
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
@ -224,7 +225,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["result"] == "success"
def test_add_invalid_type(self, app):
def test_add_invalid_type(self, app: Flask):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
@ -238,7 +239,7 @@ class TestBuiltinProviderApis:
with pytest.raises(ValueError):
method(api, "provider")
def test_add_success(self, app):
def test_add_success(self, app: Flask):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
@ -257,7 +258,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["id"] == 1
def test_update(self, app):
def test_update(self, app: Flask):
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
@ -276,7 +277,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["ok"]
def test_get_credentials(self, app):
def test_get_credentials(self, app: Flask):
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
@ -293,7 +294,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"k": "v"}
def test_icon(self, app):
def test_icon(self, app: Flask):
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
@ -307,7 +308,7 @@ class TestBuiltinProviderApis:
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app):
def test_credentials_schema(self, app: Flask):
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
@ -324,7 +325,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app):
def test_set_default_credential(self, app: Flask):
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
@ -341,7 +342,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["ok"]
def test_get_credential_info(self, app):
def test_get_credential_info(self, app: Flask):
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
@ -358,7 +359,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app):
def test_get_oauth_client_schema(self, app: Flask):
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
@ -378,10 +379,10 @@ class TestBuiltinProviderApis:
class TestApiProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_add(self, app):
def test_add(self, app: Flask):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
@ -406,7 +407,7 @@ class TestApiProviderApis:
):
assert method(api)["id"] == 1
def test_remote_schema(self, app):
def test_remote_schema(self, app: Flask):
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
@ -423,7 +424,7 @@ class TestApiProviderApis:
):
assert method(api)["schema"] == "x"
def test_list_tools(self, app):
def test_list_tools(self, app: Flask):
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
@ -440,7 +441,7 @@ class TestApiProviderApis:
):
assert method(api) == [{"tool": 1}]
def test_update(self, app):
def test_update(self, app: Flask):
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
@ -468,7 +469,7 @@ class TestApiProviderApis:
):
assert method(api)["ok"]
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
@ -485,7 +486,7 @@ class TestApiProviderApis:
):
assert method(api)["result"] == "success"
def test_get(self, app):
def test_get(self, app: Flask):
api = ToolApiProviderGetApi()
method = unwrap(api.get)
@ -505,10 +506,10 @@ class TestApiProviderApis:
class TestWorkflowApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create(self, app):
def test_create(self, app: Flask):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
@ -534,7 +535,7 @@ class TestWorkflowApis:
):
assert method(api)["id"] == 1
def test_update_invalid(self, app):
def test_update_invalid(self, app: Flask):
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
@ -560,7 +561,7 @@ class TestWorkflowApis:
result = method(api)
assert result["ok"]
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
@ -577,7 +578,7 @@ class TestWorkflowApis:
):
assert method(api)["ok"]
def test_get_error(self, app):
def test_get_error(self, app: Flask):
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
@ -594,10 +595,10 @@ class TestWorkflowApis:
class TestLists:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_builtin_list(self, app):
def test_builtin_list(self, app: Flask):
api = ToolBuiltinListApi()
method = unwrap(api.get)
@ -617,7 +618,7 @@ class TestLists:
):
assert method(api) == [{"x": 1}]
def test_api_list(self, app):
def test_api_list(self, app: Flask):
api = ToolApiListApi()
method = unwrap(api.get)
@ -637,7 +638,7 @@ class TestLists:
):
assert method(api) == [{"x": 1}]
def test_workflow_list(self, app):
def test_workflow_list(self, app: Flask):
api = ToolWorkflowListApi()
method = unwrap(api.get)
@ -660,10 +661,10 @@ class TestLists:
class TestLabels:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_labels(self, app):
def test_labels(self, app: Flask):
api = ToolLabelsApi()
method = unwrap(api.get)
@ -679,10 +680,10 @@ class TestLabels:
class TestOAuth:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_no_client(self, app):
def test_oauth_no_client(self, app: Flask):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
@ -700,7 +701,7 @@ class TestOAuth:
with pytest.raises(Forbidden):
method(api, "provider")
def test_oauth_callback_no_cookie(self, app):
def test_oauth_callback_no_cookie(self, app: Flask):
api = ToolOAuthCallback()
method = unwrap(api.get)
@ -711,10 +712,10 @@ class TestOAuth:
class TestOAuthCustomClient:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_save_custom_client(self, app):
def test_save_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.post)
@ -731,7 +732,7 @@ class TestOAuthCustomClient:
):
assert method(api, "provider")["ok"]
def test_get_custom_client(self, app):
def test_get_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.get)
@ -748,7 +749,7 @@ class TestOAuthCustomClient:
):
assert method(api, "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app):
def test_delete_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.delete)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console.workspace.trigger_providers import (
@ -45,7 +46,7 @@ def mock_user():
class TestTriggerProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_icon_success(self, app):
@ -93,7 +94,7 @@ class TestTriggerProviderApis:
class TestTriggerSubscriptionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_success(self, app):
@ -128,7 +129,7 @@ class TestTriggerSubscriptionListApi:
class TestTriggerSubscriptionBuilderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create_builder(self, app):
@ -236,7 +237,7 @@ class TestTriggerSubscriptionBuilderApis:
class TestTriggerSubscriptionCrud:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_update_rename_only(self, app):
@ -342,7 +343,7 @@ class TestTriggerSubscriptionCrud:
class TestTriggerOAuthApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
@ -480,7 +481,7 @@ class TestTriggerOAuthApis:
class TestTriggerOAuthClientManageApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_client(self, app):
@ -556,7 +557,7 @@ class TestTriggerOAuthClientManageApi:
class TestTriggerSubscriptionVerifyApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_verify_success(self, app):

View File

@ -18,6 +18,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
@ -217,10 +218,20 @@ class TestTagUnbindingPayload:
"""Test suite for TagUnbindingPayload Pydantic model."""
def test_payload_with_valid_data(self):
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
assert payload.tag_id == "tag_123"
payload = TagUnbindingPayload(tag_ids=["tag_123"], target_id="dataset_456")
assert payload.tag_ids == ["tag_123"]
assert payload.target_id == "dataset_456"
def test_payload_normalizes_legacy_tag_id(self):
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
assert payload.tag_ids == ["tag_123"]
assert payload.target_id == "dataset_456"
def test_payload_rejects_empty_tag_ids(self):
with pytest.raises(ValueError) as exc_info:
TagUnbindingPayload(tag_ids=[], target_id="dataset_456")
assert "Tag IDs is required" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Helpers
@ -236,7 +247,7 @@ def _unwrap(method):
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
# Uses the full containerised app so that Flask config, extensions, and
# blueprint registrations match production. Most tests mock the service
# layer to isolate controller logic; a few (e.g. test_list_tags_from_db)
@ -1012,6 +1023,36 @@ class TestDatasetTagUnbindingApiPost:
mock_current_user.is_dataset_editor = True
mock_tag_svc.delete_tag_binding.return_value = None
with app.test_request_context(
"/datasets/tags/unbinding",
method="POST",
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
):
api = DatasetTagUnbindingApi()
result = api.post(_=None)
assert result == ("", 204)
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
)
@patch("controllers.service_api.dataset.dataset.TagService")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_legacy_tag_id_success(
self,
mock_current_user,
mock_tag_svc,
app,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account
mock_current_user.has_edit_permission = True
mock_current_user.is_dataset_editor = True
mock_tag_svc.delete_tag_binding.return_value = None
with app.test_request_context(
"/datasets/tags/unbinding",
method="POST",
@ -1024,7 +1065,7 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge")
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
)
@patch("controllers.service_api.dataset.dataset.current_user")
@ -1038,7 +1079,7 @@ class TestDatasetTagUnbindingApiPost:
with app.test_request_context(
"/datasets/tags/unbinding",
method="POST",
json={"tag_id": "tag-1", "target_id": "ds-1"},
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
):
api = DatasetTagUnbindingApi()
with pytest.raises(Forbidden):

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.conversation import (
@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace:
class TestConversationListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/conversations"):
with pytest.raises(NotChatAppError):
ConversationListApi().get(_completion_app(), _end_user())
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
conv_id = str(uuid4())
conv = SimpleNamespace(
id=conv_id,
@ -65,16 +66,16 @@ class TestConversationListApi:
class TestConversationApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}"):
with pytest.raises(NotChatAppError):
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.delete")
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
@ -83,7 +84,7 @@ class TestConversationApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
with pytest.raises(NotFound, match="Conversation Not Exists"):
@ -92,17 +93,17 @@ class TestConversationApi:
class TestConversationRenameApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
with pytest.raises(NotChatAppError):
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.rename")
@patch("controllers.web.conversation.web_ns")
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "New Name", "auto_generate": False}
conv = SimpleNamespace(
@ -126,7 +127,7 @@ class TestConversationRenameApi:
side_effect=ConversationNotExistsError(),
)
@patch("controllers.web.conversation.web_ns")
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "X", "auto_generate": False}
@ -137,16 +138,16 @@ class TestConversationRenameApi:
class TestConversationPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.pin")
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
@ -154,7 +155,7 @@ class TestConversationPinApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
with pytest.raises(NotFound):
@ -163,16 +164,16 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.unpin")
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)

View File

@ -7,6 +7,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.forgot_password import (
ForgotPasswordCheckApi,
@ -29,7 +30,7 @@ def _patch_wraps():
class TestForgotPasswordSendEmailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi:
mock_rate_limit,
mock_get_account,
mock_send_mail,
app,
app: Flask,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi:
class TestForgotPasswordCheckApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@ -157,7 +158,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_get_account,
mock_update_account,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
@ -194,7 +195,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_token_bytes,
mock_hash_password,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
@ -182,7 +183,7 @@ class TestValidateUserAccessibility:
class TestDecodeJwtToken:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):

View File

@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers:
return WorkflowRunService(engine)
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers:
generate_entity=entity,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session):
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
def test_state_persistence_and_retrieval(self, db_session_with_containers):
def test_state_persistence_and_retrieval(self, db_session_with_containers: Session):
"""Test that pause state can be persisted and retrieved correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert retrieved_state["node_run_steps"] == 10
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_database_transaction_handling(self, db_session_with_containers):
def test_database_transaction_handling(self, db_session_with_containers: Session):
"""Test that database transactions are handled correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_model.resumed_at is None
assert pause_model.state_object_key != ""
def test_file_storage_integration(self, db_session_with_containers):
def test_file_storage_integration(self, db_session_with_containers: Session):
"""Test integration with file storage system."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_workflow_with_different_creators(self, db_session_with_containers):
def test_workflow_with_different_creators(self, db_session_with_containers: Session):
"""Test pause state with workflows created by different users."""
# Arrange - Create workflow with different creator
different_user_id = str(uuid.uuid4())
@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session):
"""Test that layer ignores non-pause events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers:
).all()
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):
def test_layer_requires_initialization(self, db_session_with_containers: Session):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()

View File

@ -15,6 +15,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
@ -40,7 +41,7 @@ class TestTenantIsolatedTaskQueueIntegration:
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
"""Create test tenant and account for testing."""
# Create account
account = Account(
@ -94,7 +95,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
@ -176,7 +177,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert len(remaining_tasks) == 2
assert remaining_tasks == ["task4", "task5"]
def test_push_and_pull_complex_objects(self, test_queue, fake):
def test_push_and_pull_complex_objects(self, test_queue, fake: Faker):
"""Test pushing and pulling complex object tasks."""
# Create complex task objects as dictionaries (not dataclass instances)
tasks = [
@ -218,7 +219,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert pulled_task["data"] == original_task["data"]
assert pulled_task["metadata"] == original_task["metadata"]
def test_mixed_task_types(self, test_queue, fake):
def test_mixed_task_types(self, test_queue, fake: Faker):
"""Test pushing and pulling mixed string and object tasks."""
string_task = "simple_string_task"
object_task = {
@ -267,7 +268,7 @@ class TestTenantIsolatedTaskQueueIntegration:
# Verify task key has expired
assert test_queue.get_task_key() is None
def test_large_task_batch(self, test_queue, fake):
def test_large_task_batch(self, test_queue, fake: Faker):
"""Test handling large batches of tasks."""
# Create large batch of tasks
large_batch = []
@ -292,7 +293,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
@ -312,7 +313,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert tasks2 == ["task1_queue2", "task2_queue2"]
assert tasks1 != tasks2
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker):
"""Test TaskWrapper serialization and deserialization roundtrip."""
# Create complex nested data
complex_data = {
@ -346,7 +347,7 @@ class TestTenantIsolatedTaskQueueIntegration:
task = test_queue.pull_tasks(1)
assert task[0] == invalid_json_task
def test_real_world_batch_processing_scenario(self, test_queue, fake):
def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker):
"""Test realistic batch processing scenario."""
# Simulate batch processing tasks
batch_tasks = []
@ -403,7 +404,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
"""Create test tenant and account for testing."""
# Create account
account = Account(
@ -435,7 +436,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker):
"""
Test compatibility with legacy queues containing only string data.
@ -465,7 +466,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker):
"""
Test complete migration scenario from legacy to new system.
@ -546,7 +547,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker):
"""
Test error recovery when legacy queue contains malformed data.

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw
class TestGetAvailableDatasetsIntegration:
def test_returns_datasets_with_available_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].name == dataset.name
def test_filters_out_datasets_with_only_archived_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_filters_out_datasets_with_only_disabled_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_filters_out_datasets_with_non_completed_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_includes_external_datasets_without_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that external datasets are returned even with no available documents.
@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].id == dataset.id
assert result[0].provider == "external"
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
# Arrange
fake = Faker()
@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].tenant_id == tenant1.id
def test_returns_empty_list_when_no_datasets_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration:
# Assert
assert result == []
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
def test_returns_only_requested_dataset_ids(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration:
class TestKnowledgeRetrievalIntegration:
def test_knowledge_retrieval_with_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration:
assert isinstance(result, list)
def test_knowledge_retrieval_no_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration:
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -31,7 +32,7 @@ class TestApiKeyAuthService:
def mock_args(self, category, provider, mock_credentials) -> dict:
return {"category": category, "provider": provider, "credentials": mock_credentials}
def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False):
def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False):
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id,
category=category,
@ -44,7 +45,7 @@ class TestApiKeyAuthService:
return binding
def test_get_provider_auth_list_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
db_session_with_containers.expire_all()
@ -56,14 +57,16 @@ class TestApiKeyAuthService:
assert len(tenant_results) == 1
assert tenant_results[0].provider == provider
def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id):
def test_get_provider_auth_list_empty(
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
):
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
def test_get_provider_auth_list_filters_disabled(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True
@ -78,7 +81,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self,
mock_encrypter,
mock_factory,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
mock_args,
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
@ -97,7 +106,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(
self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
@ -112,7 +121,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self,
mock_encrypter,
mock_factory,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
mock_args,
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
@ -128,7 +143,13 @@ class TestApiKeyAuthService:
mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key)
def test_get_auth_credentials_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials
self,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
category,
provider,
mock_credentials,
):
self._create_binding(
db_session_with_containers,
@ -144,14 +165,14 @@ class TestApiKeyAuthService:
assert result == mock_credentials
def test_get_auth_credentials_not_found(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
assert result is None
def test_get_auth_credentials_json_parsing(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
self._create_binding(
@ -169,7 +190,7 @@ class TestApiKeyAuthService:
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
def test_delete_provider_auth_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
binding = self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider
@ -183,7 +204,9 @@ class TestApiKeyAuthService:
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
assert remaining is None
def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id):
def test_delete_provider_auth_not_found(
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
):
# Should not raise when binding not found
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))

View File

@ -10,6 +10,7 @@ from uuid import uuid4
import httpx
import pytest
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
@ -114,7 +115,7 @@ class TestAuthIntegration:
assert result2[0].tenant_id == tenant_id_2
def test_cross_tenant_access_prevention(
self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)

View File

@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument:
"user_id": user_id,
}
def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_waiting_state_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in waiting state.
@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument:
mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True")
def test_pause_document_indexing_state_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in indexing state.
@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument:
assert document.is_paused is True
assert document.paused_by == mock_document_service_dependencies["user_id"]
def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_parsing_state_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in parsing state.
@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument:
db_session_with_containers.refresh(document)
assert document.is_paused is True
def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_completed_state_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to pause completed document.
@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument:
db_session_with_containers.refresh(document)
assert document.is_paused is False
def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_error_state_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to pause document in error state.
@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument:
"recover_task": mock_task,
}
def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_recover_document_paused_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful recovery of paused document.
@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument:
document.dataset_id, document.id
)
def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_recover_document_not_paused_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to recover non-paused document.
@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument:
"user_id": user_id,
}
def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_retry_document_single_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful retry of single document.
@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument:
dataset.id, [document.id], mock_document_service_dependencies["user_id"]
)
def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_retry_document_multiple_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful retry of multiple documents.
@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument:
)
def test_retry_document_concurrent_retry_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is already being retried.
@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument:
assert document.indexing_status == IndexingStatus.ERROR
def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when current_user is missing.
@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
}
def test_batch_update_document_status_enable_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch enabling of documents.
@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
assert mock_document_service_dependencies["add_task"].delay.call_count == 2
def test_batch_update_document_status_disable_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch disabling of documents.
@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_archive_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch archiving of documents.
@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_unarchive_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch unarchiving of documents.
@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_empty_list(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test handling of empty document list.
@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_document_status_document_indexing_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is being indexed.
@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument:
"current_user": mock_current_user,
}
def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies):
"""
Test successful document renaming.
@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument:
assert result == document
assert document.name == new_name
def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_with_built_in_fields(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test document renaming with built-in fields enabled.
@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument:
assert document.doc_metadata["document_name"] == new_name
assert document.doc_metadata["existing_key"] == "existing_value"
def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_with_upload_file(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test document renaming with associated upload file.
@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument:
assert upload_file.name == new_name
def test_rename_document_dataset_not_found_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when dataset is not found.
@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument:
with pytest.raises(ValueError, match="Dataset not found"):
DocumentService.rename_document(dataset_id, document_id, new_name)
def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_not_found_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is not found.
@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument:
with pytest.raises(ValueError, match="Document not found"):
DocumentService.rename_document(dataset.id, document_id, new_name)
def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_permission_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when user lacks permission.

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from redis import RedisError
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin
@ -122,7 +123,7 @@ class TestSyncAccountDeletion:
mock_queue_task.assert_not_called()
def test_sync_account_deletion_multiple_workspaces(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_ids = [str(uuid4()) for _ in range(3)]
@ -144,7 +145,7 @@ class TestSyncAccountDeletion:
assert queued_workspace_ids == set(tenant_ids)
def test_sync_account_deletion_no_workspaces(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
@ -155,7 +156,7 @@ class TestSyncAccountDeletion:
mock_queue_task.assert_not_called()
def test_sync_account_deletion_partial_failure(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_ids = [str(uuid4()) for _ in range(3)]
@ -180,7 +181,7 @@ class TestSyncAccountDeletion:
assert mock_queue_task.call_count == 3
def test_sync_account_deletion_all_failures(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_id = str(uuid4())

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from models.model import App, RecommendedApp, Site
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_type import RecommendAppType
@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval:
class TestFetchRecommendedAppsFromDb:
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers):
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)
@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb:
assert "assistant" in result["categories"]
assert "writing" in result["categories"]
def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers):
def test_falls_back_to_default_language_when_empty(
self, flask_app_with_containers, db_session_with_containers: Session
):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)
@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb:
app_ids = {r["app_id"] for r in result["recommended_apps"]}
assert app1.id in app_ids
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers):
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
_create_site(db_session_with_containers, app_id=app1.id)
@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb:
app_ids = {r["app_id"] for r in result["recommended_apps"]}
assert app1.id not in app_ids
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers):
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_recommended_app(db_session_with_containers, app_id=app1.id)
@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb:
class TestFetchRecommendedAppDetailFromDb:
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers):
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session):
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4()))
assert result is None
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers):
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
_create_recommended_app(db_session_with_containers, app_id=app1.id)
@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb:
assert result is None
@patch("services.recommend_app.database.database_retrieval.AppDslService")
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers):
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)

View File

@ -2,6 +2,7 @@ import copy
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService:
# for consistency with other test files
return {}
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_baichuan_model_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful prompt generation for Baichuan model.
@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService:
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_common_model_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful prompt generation for common models.
@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_prompt_case_insensitive_baichuan_detection(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan model detection is case insensitive.
@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService:
assert BAICHUAN_CONTEXT in prompt_text
def test_get_common_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with completion mode.
@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService:
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_common_prompt_chat_app_chat_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with chat mode.
@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with completion mode.
@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with chat mode.
@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService:
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_common_prompt_no_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation without context.
@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported app mode.
@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_common_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported model mode.
@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService:
# Assert: Verify empty dict is returned
assert result == {}
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_completion_prompt_with_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test completion prompt generation with context.
@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService:
assert result_text == CONTEXT + original_text
def test_get_completion_prompt_without_context(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test completion prompt generation without context.
@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService:
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_chat_prompt_with_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test chat prompt generation with context.
@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService:
assert original_text in result_text
assert result_text == CONTEXT + original_text
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_chat_prompt_without_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test chat prompt generation without context.
@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService:
assert CONTEXT not in result_text
def test_get_baichuan_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with completion mode.
@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_chat_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with chat mode.
@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with completion mode.
@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with chat mode.
@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService:
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_baichuan_prompt_no_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation without context.
@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported app mode.
@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_baichuan_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported model mode.
@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_prompt_all_app_modes_common_model(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with common model.
@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService:
assert result != {}
def test_get_prompt_all_app_modes_baichuan_model(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with Baichuan model.
@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService:
assert result is not None
assert result != {}
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test prompt generation with edge cases.
@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService:
# Should either return a valid result or empty dict, but not crash
assert result is not None
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test that original templates are not modified.
@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService:
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
def test_baichuan_template_immutability(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that original Baichuan templates are not modified.
@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService:
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
def test_context_integration_consistency(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test consistency of context integration across different scenarios.
@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService:
assert prompt_text.startswith(CONTEXT)
def test_baichuan_context_integration_consistency(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test consistency of Baichuan context integration across different scenarios.

View File

@ -10,6 +10,8 @@ from uuid import uuid4
import pytest
import yaml
from faker import Faker
from flask import Flask
from sqlalchemy.orm import Session
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
@ -88,7 +90,7 @@ class TestAppDslService:
"""Integration tests for AppDslService using testcontainers."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -129,7 +131,7 @@ class TestAppDslService:
"enterprise_service": mock_enterprise_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
fake = Faker()
with patch("services.account_service.FeatureService") as mock_account_feature_service:
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
@ -206,7 +208,7 @@ class TestAppDslService:
# ── Import: Validation ────────────────────────────────────────────
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers):
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid import_mode"):
service.import_app(
@ -215,7 +217,7 @@ class TestAppDslService:
yaml_content="version: '0.1.0'",
)
def test_import_app_missing_yaml_content(self, db_session_with_containers):
def test_import_app_missing_yaml_content(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -225,7 +227,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "yaml_content is required" in result.error
def test_import_app_missing_yaml_url(self, db_session_with_containers):
def test_import_app_missing_yaml_url(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -235,7 +237,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "yaml_url is required" in result.error
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers):
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -245,7 +247,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "content must be a mapping" in result.error
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers):
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}})
result = service.import_app(
@ -256,7 +258,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Invalid version type" in result.error
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers):
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -266,7 +268,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Missing app data" in result.error
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def bad_safe_load(_content: str):
raise yaml.YAMLError("bad")
@ -281,7 +283,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert result.error.startswith("Invalid YAML format:")
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
monkeypatch.setattr(
AppDslService,
"_create_or_update_app",
@ -299,7 +301,7 @@ class TestAppDslService:
# ── Import: YAML URL ──────────────────────────────────────────────
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
monkeypatch.setattr(
app_dsl_service.ssrf_proxy,
"get",
@ -315,7 +317,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Error fetching YAML from URL: boom" in result.error
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch):
response = MagicMock()
response.content = b""
response.raise_for_status.return_value = None
@ -330,7 +332,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Empty content" in result.error
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch):
response = MagicMock()
response.content = b"x" * (DSL_MAX_SIZE + 1)
response.raise_for_status.return_value = None
@ -345,7 +347,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "File size exceeds" in result.error
def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_user_attachments_keeps_original_url(
self, db_session_with_containers: Session, monkeypatch
):
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
yaml_bytes = _pending_yaml_content()
@ -371,7 +375,7 @@ class TestAppDslService:
assert result.imported_dsl_version == "99.0.0"
assert requested_urls == [yaml_url]
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch):
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
yaml_bytes = _pending_yaml_content()
@ -400,7 +404,7 @@ class TestAppDslService:
# ── Import: App ID checks ────────────────────────────────────────
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers):
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -412,7 +416,7 @@ class TestAppDslService:
assert result.error == "App not found"
def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
assert app.mode == "chat"
@ -429,7 +433,7 @@ class TestAppDslService:
# ── Import: Flow ──────────────────────────────────────────────────
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers):
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -449,7 +453,7 @@ class TestAppDslService:
assert stored is not None
def test_import_app_completed_uses_declared_dependencies(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
@ -483,7 +487,7 @@ class TestAppDslService:
@pytest.mark.parametrize("has_workflow", [True, False])
def test_import_app_legacy_versions_extract_dependencies(
self, db_session_with_containers, monkeypatch, has_workflow: bool
self, db_session_with_containers: Session, monkeypatch, has_workflow: bool
):
monkeypatch.setattr(
AppDslService,
@ -540,13 +544,13 @@ class TestAppDslService:
# ── Confirm Import ────────────────────────────────────────────────
def test_confirm_import_expired_returns_failed(self, db_session_with_containers):
def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.confirm_import(import_id=str(uuid4()), account=_account_mock())
assert result.status == ImportStatus.FAILED
assert "expired" in result.error
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch):
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
@ -579,7 +583,7 @@ class TestAppDslService:
assert result.app_id == created_app.id
assert redis_client.get(redis_key) is None
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers):
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123")
@ -589,7 +593,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "validation error" in result.error
def test_confirm_import_exception_returns_failed(self, db_session_with_containers):
def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json")
@ -600,13 +604,13 @@ class TestAppDslService:
# ── Check Dependencies ────────────────────────────────────────────
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers):
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
app_model = _app_stub()
result = service.check_dependencies(app_model=app_model)
assert result.leaked_dependencies == []
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch):
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch):
app_id = str(uuid4())
pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id)
redis_client.setex(
@ -634,7 +638,9 @@ class TestAppDslService:
result = service.check_dependencies(app_model=_app_stub(id=app_id))
assert len(result.leaked_dependencies) == 1
def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies):
def test_check_dependencies_with_real_app(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
@ -650,12 +656,12 @@ class TestAppDslService:
# ── Create/Update App ─────────────────────────────────────────────
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers):
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="loss app mode"):
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch):
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch):
fixed_now = object()
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
@ -707,7 +713,7 @@ class TestAppDslService:
assert app.icon_background == "#222222"
assert app.updated_at is fixed_now
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers):
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session):
account = _account_mock()
account.current_tenant_id = None
service = AppDslService(db_session_with_containers)
@ -719,7 +725,7 @@ class TestAppDslService:
)
def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
@ -755,7 +761,7 @@ class TestAppDslService:
stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}")
assert stored is not None
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers):
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing workflow data"):
service._create_or_update_app(
@ -764,7 +770,7 @@ class TestAppDslService:
account=_account_mock(),
)
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers):
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing model_config"):
service._create_or_update_app(
@ -774,7 +780,7 @@ class TestAppDslService:
)
def test_create_or_update_app_chat_creates_model_config_and_sends_event(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.app_model_config_id = None
@ -793,7 +799,7 @@ class TestAppDslService:
db_session_with_containers.expire_all()
assert app.app_model_config_id is not None
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers):
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid app mode"):
service._create_or_update_app(
@ -870,7 +876,7 @@ class TestAppDslService:
assert data["app"]["icon_type"] == "image"
assert data["app"]["icon_background"] == "#FFEAD5"
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
model_config = AppModelConfig(
@ -908,7 +914,9 @@ class TestAppDslService:
assert "model_config" in exported_data
assert "dependencies" in exported_data
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_workflow_app_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"
db_session_with_containers.commit()
@ -941,7 +949,9 @@ class TestAppDslService:
assert "workflow" in exported_data
assert "dependencies" in exported_data
def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_with_workflow_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"
db_session_with_containers.commit()
@ -981,7 +991,7 @@ class TestAppDslService:
assert "workflow" in exported_data
def test_export_dsl_with_invalid_workflow_id_raises_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"

View File

@ -7,7 +7,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
@ -60,7 +60,7 @@ class TestAttachmentService:
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
@ -70,7 +70,7 @@ class TestAttachmentService:
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache:
"""
@pytest.fixture(autouse=True)
def setup_redis_cleanup(self, flask_app_with_containers):
def setup_redis_cleanup(self, flask_app_with_containers: Flask):
"""Clean up Redis cache before and after each test."""
with flask_app_with_containers.app_context():
# Clean up before test
@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache:
return value
return None
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when all tenants are in cache."""
with flask_app_with_containers.app_context():
# Arrange
@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify API was not called
mock_get_plan_bulk.assert_not_called()
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when all tenants are not in cache."""
with flask_app_with_containers.app_context():
# Arrange
@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache:
assert ttl_1 > 0
assert ttl_1 <= 600 # Should be <= 600 seconds
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
with flask_app_with_containers.app_context():
# Arrange
@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache:
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == missing_plan["tenant-3"]
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask):
"""Test fallback to API when Redis mget fails."""
with flask_app_with_containers.app_context():
# Arrange
@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache:
assert cached_1 is not None
assert cached_2 is not None
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask):
"""Test fallback to API when cache contains invalid JSON."""
with flask_app_with_containers.app_context():
# Arrange
@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache:
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == expected_plans["tenant-3"]
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask):
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
with flask_app_with_containers.app_context():
# Arrange
@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify API was called for tenant-2 and tenant-3
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask):
"""Test that pipeline failure doesn't affect return value."""
with flask_app_with_containers.app_context():
# Arrange
@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify pipeline was attempted
mock_pipeline.assert_called_once()
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask):
"""Test with empty tenant_ids list."""
with flask_app_with_containers.app_context():
# Act
@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# But we should check that mget was not called at all
# Since we can't easily verify this without more mocking, we just verify the result
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask):
"""Test that expired cache keys are treated as cache misses."""
with flask_app_with_containers.app_context():
# Arrange

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory:
class TestConversationServicePagination:
"""Test conversation pagination operations."""
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers):
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session):
"""
Test that non-empty include_ids filters properly.
@ -204,7 +205,7 @@ class TestConversationServicePagination:
returned_ids = {conversation.id for conversation in result.data}
assert returned_ids == {conversations[0].id, conversations[1].id}
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers):
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session):
"""
Test that empty exclude_ids doesn't filter.
@ -237,7 +238,7 @@ class TestConversationServicePagination:
# Assert
assert len(result.data) == len(conversations)
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers):
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session):
"""
Test that non-empty exclude_ids filters properly.
@ -271,7 +272,7 @@ class TestConversationServicePagination:
returned_ids = {conversation.id for conversation in result.data}
assert returned_ids == {conversations[2].id}
def test_pagination_with_sorting_descending(self, db_session_with_containers):
def test_pagination_with_sorting_descending(self, db_session_with_containers: Session):
"""
Test pagination with descending sort order.
@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers):
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session):
"""
Test message pagination without specifying first_id.
@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == 3 # All 3 messages returned
assert result.has_more is False # No more messages available (3 < limit of 10)
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers):
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session):
"""
Test message pagination with first_id specified.
@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == 2 # Only 2 messages returned after first_id
assert result.has_more is False # No more messages available (2 < limit of 10)
def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers):
def test_pagination_by_first_id_raises_error_when_first_message_not_found(
self, db_session_with_containers: Session
):
"""
Test that FirstMessageNotExistsError is raised when first_id doesn't exist.
@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation:
limit=10,
)
def test_pagination_with_has_more_flag(self, db_session_with_containers):
def test_pagination_with_has_more_flag(self, db_session_with_containers: Session):
"""
Test that has_more flag is correctly set when there are more messages.
@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
def test_pagination_with_ascending_order(self, db_session_with_containers):
def test_pagination_with_ascending_order(self, db_session_with_containers: Session):
"""
Test message pagination with ascending order.
@ -512,7 +515,7 @@ class TestConversationServiceSummarization:
"""
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers):
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session):
"""
Test successful auto-generation of conversation name.
@ -552,7 +555,7 @@ class TestConversationServiceSummarization:
app_model.tenant_id, first_message.query, conversation.id, app_model.id
)
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers):
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session):
"""
Test that MessageNotExistsError is raised when conversation has no messages.
@ -571,7 +574,9 @@ class TestConversationServiceSummarization:
ConversationService.auto_generate_name(app_model, conversation)
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers):
def test_auto_generate_name_handles_llm_failure_gracefully(
self, mock_llm_generator, db_session_with_containers: Session
):
"""
Test that LLM generation failures are suppressed and don't crash.
@ -604,7 +609,7 @@ class TestConversationServiceSummarization:
assert conversation.name == original_name # Name remains unchanged
@patch("services.conversation_service.naive_utc_now")
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers):
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session):
"""
Test renaming conversation with manual name.
@ -638,7 +643,7 @@ class TestConversationServiceSummarization:
assert conversation.updated_at == mock_time
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session):
"""
Test rename delegates to auto_generate_name when auto_generate is True.
@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_create_annotation_from_message(
self, mock_current_account, mock_add_task, db_session_with_containers: Session
):
"""
Test creating annotation from existing message.
@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_create_annotation_without_message(
self, mock_current_account, mock_add_task, db_session_with_containers: Session
):
"""
Test creating standalone annotation without message.
@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
"""
Test updating an existing annotation.
@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation:
mock_add_task.delay.assert_not_called()
@patch("services.annotation_service.current_account_with_tenant")
def test_get_annotation_list(self, mock_current_account, db_session_with_containers):
def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session):
"""
Test retrieving paginated annotation list.
@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation:
assert result_total == 5
@patch("services.annotation_service.current_account_with_tenant")
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers):
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session):
"""
Test retrieving annotations with keyword filtering.
@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
"""
Test direct annotation insertion without message reference.
@ -919,7 +928,7 @@ class TestConversationServiceExport:
Tests retrieving conversation data for export purposes.
"""
def test_get_conversation_success(self, db_session_with_containers):
def test_get_conversation_success(self, db_session_with_containers: Session):
"""Test successful retrieval of conversation."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -937,7 +946,7 @@ class TestConversationServiceExport:
# Assert
assert result == conversation
def test_get_conversation_not_found(self, db_session_with_containers):
def test_get_conversation_not_found(self, db_session_with_containers: Session):
"""Test ConversationNotExistsError when conversation doesn't exist."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -949,7 +958,7 @@ class TestConversationServiceExport:
ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user)
@patch("services.annotation_service.current_account_with_tenant")
def test_export_annotation_list(self, mock_current_account, db_session_with_containers):
def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session):
"""Test exporting all annotations for an app."""
# Arrange
app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -977,7 +986,7 @@ class TestConversationServiceExport:
# Assert
assert len(result) == 10
def test_get_message_success(self, db_session_with_containers):
def test_get_message_success(self, db_session_with_containers: Session):
"""Test successful retrieval of a message."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -1001,7 +1010,7 @@ class TestConversationServiceExport:
# Assert
assert result == message
def test_get_message_not_found(self, db_session_with_containers):
def test_get_message_not_found(self, db_session_with_containers: Session):
"""Test MessageNotExistsError when message doesn't exist."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -1012,7 +1021,7 @@ class TestConversationServiceExport:
with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4()))
def test_get_conversation_for_end_user(self, db_session_with_containers):
def test_get_conversation_for_end_user(self, db_session_with_containers: Session):
"""
Test retrieving conversation created by end user via API.
@ -1038,7 +1047,7 @@ class TestConversationServiceExport:
assert result == conversation
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_conversation(self, mock_delete_task, db_session_with_containers):
def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session):
"""
Test conversation deletion with async cleanup.
@ -1071,7 +1080,7 @@ class TestConversationServiceExport:
mock_delete_task.delay.assert_called_once_with(conversation_id)
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers):
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session):
"""
Test deletion is denied when conversation belongs to a different account.
"""
@ -1102,7 +1111,7 @@ class TestConversationServiceExport:
mock_delete_task.delay.assert_not_called()
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session):
"""
Test that delete propagates exceptions and does not trigger the cleanup task.

View File

@ -5,7 +5,8 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from flask import Flask
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory:
@pytest.fixture
def real_conversation_service_session_factory(flask_app_with_containers):
def real_conversation_service_session_factory(flask_app_with_containers: Flask):
del flask_app_with_containers
real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers):
class TestConversationServiceVariables:
def test_get_conversational_variable_success(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -200,7 +201,7 @@ class TestConversationServiceVariables:
assert result.has_more is False
def test_get_conversational_variable_with_last_id(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -242,7 +243,7 @@ class TestConversationServiceVariables:
assert result.has_more is False
def test_get_conversational_variable_last_id_not_found_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -259,7 +260,7 @@ class TestConversationServiceVariables:
)
def test_get_conversational_variable_sets_has_more(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -287,7 +288,7 @@ class TestConversationServiceVariables:
assert result.has_more is True
def test_update_conversation_variable_success(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -320,7 +321,7 @@ class TestConversationServiceVariables:
assert result["updated_at"] == updated_at
def test_update_conversation_variable_not_found_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -337,7 +338,7 @@ class TestConversationServiceVariables:
)
def test_update_conversation_variable_type_mismatch_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -360,7 +361,7 @@ class TestConversationServiceVariables:
)
def test_update_conversation_variable_integer_number_compatibility(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -390,7 +391,7 @@ class TestConversationServiceVariables:
class TestConversationServicePaginationWithContainers:
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers):
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers:
invoke_from=InvokeFrom.WEB_APP,
)
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers):
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
base_time = datetime(2024, 1, 1, 8, 0, 0)
@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers:
assert newest.id != middle.id
assert [conversation.id for conversation in result.data] == [oldest.id]
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers):
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha")
@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers:
assert alpha.id != beta.id
assert [conversation.id for conversation in result.data] == [gamma.id]
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers):
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
end_user = factory.create_end_user(db_session_with_containers, app)
@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers:
assert account_conversation.id != end_user_conversation.id
assert [conversation.id for conversation in result.data] == [end_user_conversation.id]
def test_pagination_filters_to_account_console_source(self, db_session_with_containers):
def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
end_user = factory.create_end_user(db_session_with_containers, app)

View File

@ -3,7 +3,7 @@
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from graphon.variables import StringVariable
@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
self,
db_session_with_containers: Session,
*,
conversation_id: str,
variable: StringVariable,
app_id: str | None = None,
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
@ -25,7 +30,7 @@ class TestConversationVariableUpdater:
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
@ -42,7 +47,7 @@ class TestConversationVariableUpdater:
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
@ -50,7 +55,7 @@ class TestConversationVariableUpdater:
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()

View File

@ -3,6 +3,7 @@
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
@ -14,7 +15,7 @@ class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers):
def test_create_default_pool(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
@ -25,7 +26,7 @@ class TestCreditPoolService:
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
@ -35,17 +36,17 @@ class TestCreditPoolService:
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session):
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
@ -53,7 +54,7 @@ class TestCreditPoolService:
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
@ -64,11 +65,11 @@ class TestCreditPoolService:
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session):
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
@ -77,7 +78,7 @@ class TestCreditPoolService:
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
@ -89,7 +90,7 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5

View File

@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory:
class TestDatasetPermissionServiceGetPartialMemberList:
"""Verify partial-member list reads against persisted DatasetPermission rows."""
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session):
"""
Test retrieving partial member list with multiple members.
"""
@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
assert set(result) == set(expected_account_ids)
assert len(result) == 3
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session):
"""
Test retrieving partial member list with single member.
"""
@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
assert set(result) == set(expected_account_ids)
assert len(result) == 1
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session):
"""
Test retrieving partial member list when no members exist.
"""
@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
class TestDatasetPermissionServiceUpdatePartialMemberList:
"""Verify partial-member list updates against persisted DatasetPermission rows."""
def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session):
"""
Test adding new partial members to a dataset.
"""
@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {member_1.id, member_2.id}
def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session):
"""
Test replacing existing partial members with new ones.
"""
@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {new_member_1.id, new_member_2.id}
def test_update_partial_member_list_empty_list(self, db_session_with_containers):
def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session):
"""
Test updating with empty member list (clearing all members).
"""
@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
"""
Test error handling and rollback on database error.
"""
@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
class TestDatasetPermissionServiceClearPartialMemberList:
"""Verify partial-member clearing against persisted DatasetPermission rows."""
def test_clear_partial_member_list_success(self, db_session_with_containers):
def test_clear_partial_member_list_success(self, db_session_with_containers: Session):
"""
Test successful clearing of partial member list.
"""
@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session):
"""
Test clearing partial member list when no members exist.
"""
@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
"""
Test error handling and rollback on database error.
"""
@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session):
"""Test that users from different tenants cannot access dataset."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other_user)
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session):
"""Test that tenant owners can access any dataset regardless of permission level."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, owner)
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session):
"""Test ONLY_ME permission allows only the dataset creator to access."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session):
"""Test ONLY_ME permission denies access to non-creators."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other)
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
def test_check_dataset_permission_partial_members_with_permission_success(
self, db_session_with_containers: Session
):
"""
Test that user with explicit permission can access partial_members dataset.
"""
@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission:
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert user.id in permissions
def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
def test_check_dataset_permission_partial_members_without_permission_error(
self, db_session_with_containers: Session
):
"""
Test error when user without permission tries to access partial_members dataset.
"""
@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session):
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)

View File

@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration:
class TestDocumentServicePauseRecoverRetry:
"""Tests for pause/recover/retry orchestration using real DB and Redis."""
def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"):
factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry:
db_session_with_containers.commit()
return doc, account
def test_pause_document_success(self, db_session_with_containers):
def test_pause_document_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry:
assert redis_client.get(cache_key) is not None
redis_client.delete(cache_key)
def test_pause_document_invalid_status_error(self, db_session_with_containers):
def test_pause_document_invalid_status_error(self, db_session_with_containers: Session):
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry:
with pytest.raises(DocumentIndexingError):
DocumentService.pause_document(doc)
def test_recover_document_success(self, db_session_with_containers):
def test_recover_document_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry:
assert redis_client.get(cache_key) is None
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)
def test_retry_document_indexing_success(self, db_session_with_containers):
def test_retry_document_indexing_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService

View File

@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin
from services.dataset_service import DatasetService
@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
permission="only_me",
)
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session):
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
mock_user = Mock(id=None)

View File

@ -3,6 +3,8 @@
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory:
class TestDatasetServiceDeleteDataset:
"""Integration coverage for DatasetService.delete_dataset using testcontainers."""
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
"""Delete a dataset with documents and dispatch cleanup through the real signal handler."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset:
dataset.pipeline_id,
)
def test_delete_empty_dataset_success(self, db_session_with_containers):
def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
"""Delete an empty dataset without scheduling cleanup when both gating fields are absent."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
"""Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers):
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session):
"""Delete a dataset without cleanup when indexing exists but doc_form resolves to None."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_not_found(self, db_session_with_containers):
def test_delete_dataset_not_found(self, db_session_with_containers: Session):
"""Return False without scheduling cleanup when the target dataset does not exist."""
# Arrange
owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle:
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask):
with flask_app_with_containers.app_context():
with pytest.raises(NotFound, match="Dataset not found"):
DatasetService.update_dataset_api_status(str(uuid4()), True)
@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration:
assert persisted.type == "dataset"
assert persisted.collection_name
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask):
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))

View File

@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from graphon.enums import WorkflowExecutionStatus
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.commit()
return run
def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None:
archive_log = WorkflowArchiveLog(
tenant_id=run.tenant_id,
app_id=run.app_id,
@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.add(archive_log)
db_session_with_containers.commit()
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session):
deleter = ArchivedWorkflowRunDeletion()
missing_run_id = str(uuid4())
@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == f"Workflow run {missing_run_id} not found"
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == f"Workflow run {run.id} is not archived"
def test_delete_batch_uses_repo(self, db_session_with_containers):
def test_delete_batch_uses_repo(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion:
).all()
assert remaining_runs == []
def test_delete_run_calls_repo(self, db_session_with_containers):
def test_delete_run_calls_repo(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion:
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None
def test_delete_run_dry_run(self, db_session_with_containers):
def test_delete_run_dry_run(self, db_session_with_containers: Session):
"""Dry run should return success without actually deleting."""
tenant_id = str(uuid4())
run = self._create_workflow_run(
@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
def test_delete_run_exception_returns_error(self, db_session_with_containers):
def test_delete_run_exception_returns_error(self, db_session_with_containers: Session):
"""Exception during deletion should return failure result."""
from unittest.mock import MagicMock, patch
@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == "Database error"
def test_delete_by_run_id_success(self, db_session_with_containers):
def test_delete_by_run_id_success(self, db_session_with_containers: Session):
"""Successfully delete an archived workflow run by ID."""
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expunge_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is None
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session):
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
deleter = ArchivedWorkflowRunDeletion()

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory):
"""Test getting or creating end user with custom user_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser:
assert result.type == InvokeFrom.SERVICE_API
assert result.is_anonymous is False
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory):
"""Test getting or creating end user without user_id uses default session."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser:
# Verify _is_anonymous is set correctly (property always returns False)
assert result._is_anonymous is True
def test_get_existing_end_user(self, db_session_with_containers, factory):
def test_get_existing_end_user(self, db_session_with_containers: Session, factory):
"""Test retrieving an existing end user."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory):
"""Test creating new end user with SERVICE_API type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.app_id == app_id
assert result.session_id == user_id
def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory):
"""Test creating new end user with WEB_APP type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.type == InvokeFrom.WEB_APP
@patch("services.end_user_service.logger")
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory):
"""Test upgrading legacy end user with different type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert "Upgrading legacy EndUser" in log_call
@patch("services.end_user_service.logger")
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory):
"""Test retrieving existing end user with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.type == InvokeFrom.SERVICE_API
mock_logger.info.assert_not_called()
def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory):
"""Test creating anonymous user when user_id is None."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result._is_anonymous is True
assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory):
"""Test that query ordering prioritizes records with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.id == matching.id
assert result.id != non_matching.id
def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory):
"""Test that external_user_id is set to match session_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType:
InvokeFrom.DEBUGGER,
],
)
def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
def test_create_end_user_with_different_invoke_types(
self, db_session_with_containers: Session, invoke_type, factory
):
"""Test creating end users with different InvokeFrom types."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory):
app = factory.create_app_and_account(db_session_with_containers)
existing_user = factory.create_end_user(
db_session_with_containers,
@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById:
assert result is not None
assert result.id == existing_user.id
def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory):
app = factory.create_app_and_account(db_session_with_containers)
result = EndUserService.get_end_user_by_id(
@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch:
def factory(self):
return TestEndUserServiceFactory()
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3):
"""Create multiple apps under the same tenant."""
first_app = factory.create_app_and_account(db_session_with_containers)
tenant_id = first_app.tenant_id
@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch:
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
return tenant_id, all_apps
def test_create_batch_empty_app_ids(self, db_session_with_containers):
def test_create_batch_empty_app_ids(self, db_session_with_containers: Session):
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
)
assert result == {}
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch:
assert result[app_id].session_id == user_id
assert result[app_id].type == InvokeFrom.SERVICE_API
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch:
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
assert end_user._is_anonymous is True
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
user_id = f"user-{uuid4()}"
@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch:
assert len(result) == 2
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch:
for app_id in app_ids:
assert first_result[app_id].id == second_result[app_id].id
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
user_id = f"user-{uuid4()}"
@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch:
"invoke_type",
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
)
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
user_id = f"user-{uuid4()}"

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from services.feature_service import (
@ -81,7 +82,7 @@ class TestFeatureService:
fake = Faker()
return fake.uuid4()
def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful feature retrieval with billing and enterprise enabled.
@ -156,7 +157,7 @@ class TestFeatureService:
tenant_id
)
def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test feature retrieval for sandbox plan with specific limitations.
@ -222,7 +223,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_knowledge_rate_limit_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful knowledge rate limit retrieval with billing enabled.
@ -255,7 +258,7 @@ class TestFeatureService:
tenant_id
)
def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful system features retrieval with enterprise and marketplace enabled.
@ -332,7 +335,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_unauthenticated(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval for an unauthenticated user.
@ -386,7 +391,9 @@ class TestFeatureService:
# Marketplace should be visible
assert result.enable_marketplace is True
def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_basic_config(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with basic configuration (no enterprise).
@ -436,7 +443,9 @@ class TestFeatureService:
# Verify plugin package size (uses default value from dify_config)
assert result.max_plugin_package_size == 15728640
def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_billing_disabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when billing is disabled.
@ -492,7 +501,7 @@ class TestFeatureService:
assert result.webapp_copyright_enabled is False
def test_get_knowledge_rate_limit_billing_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test knowledge rate limit retrieval when billing is disabled.
@ -523,7 +532,9 @@ class TestFeatureService:
# Verify no billing service calls
mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called()
def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_enterprise_only(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with enterprise enabled but billing disabled.
@ -583,7 +594,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
def test_get_system_features_enterprise_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval when enterprise is disabled.
@ -640,7 +651,7 @@ class TestFeatureService:
# Verify no enterprise service calls
mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called()
def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test feature retrieval without tenant ID (billing disabled).
@ -686,7 +697,9 @@ class TestFeatureService:
# Verify no billing service calls
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_partial_billing_info(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with partial billing information.
@ -746,7 +759,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_vector_space(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case vector space configuration.
@ -807,7 +822,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_webapp_auth(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case webapp auth configuration.
@ -863,7 +878,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_members_quota(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case members quota configuration.
@ -924,7 +941,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_plugin_installation_permission_scopes(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with different plugin installation permission scopes.
@ -1023,7 +1040,7 @@ class TestFeatureService:
assert result.plugin_installation_permission.restrict_to_marketplace_only is True
def test_get_features_workspace_members_missing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when workspace members info is missing from enterprise.
@ -1064,7 +1081,9 @@ class TestFeatureService:
tenant_id
)
def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_license_inactive(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with inactive license.
@ -1117,7 +1136,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_system_features_partial_enterprise_info(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with partial enterprise information.
@ -1186,7 +1205,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_limits(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case limit values.
@ -1244,7 +1265,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_protocols(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case protocol values.
@ -1297,7 +1318,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_education(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case education configuration.
@ -1353,7 +1376,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_license_limitation_model_is_available(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test LicenseLimitationModel.is_available method with various scenarios.
@ -1394,7 +1417,7 @@ class TestFeatureService:
assert exact_limit.is_available(3) is True
def test_get_features_workspace_members_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when workspace members are disabled in enterprise.
@ -1433,7 +1456,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id)
def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_license_expired(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with expired license.
@ -1486,7 +1511,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_docs_processing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case document processing configuration.
@ -1544,7 +1569,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_branding(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case branding configuration.
@ -1606,7 +1631,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_annotation_quota(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case annotation quota configuration.
@ -1668,7 +1693,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_features_edge_case_documents_upload(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case documents upload settings.
@ -1733,7 +1758,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_license_lost(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features with lost license status.
@ -1784,7 +1809,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_education_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with education feature disabled.

View File

@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from configs import dify_config
from core.workflow.human_input_adapter import (
@ -88,7 +89,7 @@ class TestDeliveryTestRegistry:
with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."):
registry.dispatch(context=context, method=method)
def test_default(self, flask_app_with_containers, db_session_with_containers):
def test_default(self, flask_app_with_containers, db_session_with_containers: Session):
registry = DeliveryTestRegistry.default()
assert len(registry._handlers) == 1
assert isinstance(registry._handlers[0], EmailDeliveryTestHandler)
@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler:
)
assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"]
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers):
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
account = Account(name="Test User", email="member@example.com")
db_session_with_containers.add(account)
@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler:
)
assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"]
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers):
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com")
account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com")

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.dataset import Dataset, DatasetMetadataBinding, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -65,7 +66,7 @@ class TestMetadataPartialUpdate:
yield account
def test_partial_update_merges_metadata(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -92,7 +93,7 @@ class TestMetadataPartialUpdate:
assert updated_doc.doc_metadata["new_key"] == "new_value"
def test_full_update_replaces_metadata(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -119,7 +120,7 @@ class TestMetadataPartialUpdate:
assert "existing_key" not in updated_doc.doc_metadata
def test_partial_update_skips_existing_binding(
self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -159,7 +160,7 @@ class TestMetadataPartialUpdate:
assert len(bindings) == 1
def test_rollback_called_on_commit_failure(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
@ -25,7 +26,7 @@ from services.oauth_server import (
class TestOAuthServerServiceGetProviderApp:
"""DB-backed tests for get_oauth_provider_app."""
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp:
app = OAuthProviderApp(
app_icon="icon.png",
client_id=client_id,
@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp:
db_session_with_containers.commit()
return app
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session):
client_id = f"client-{uuid4()}"
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp:
assert result.client_id == client_id
assert result.id == created.id
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session):
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
assert result is None

View File

@ -8,6 +8,7 @@ from datetime import datetime
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.workflow import WorkflowPause, WorkflowRun
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
@ -39,7 +40,7 @@ class TestWorkflowRunRestore:
assert result["created_at"].month == 1
assert result["name"] == "test"
def test_restore_table_records_returns_rowcount(self, db_session_with_containers):
def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session):
"""Restore should return inserted rowcount."""
restore = WorkflowRunRestore()
record_id = str(uuid4())
@ -65,7 +66,7 @@ class TestWorkflowRunRestore:
restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id))
assert restored_pause is not None
def test_restore_table_records_unknown_table(self, db_session_with_containers):
def test_restore_table_records_unknown_table(self, db_session_with_containers: Session):
"""Unknown table names should be ignored gracefully."""
restore = WorkflowRunRestore()

View File

@ -1099,38 +1099,39 @@ class TestTagService:
db_session_with_containers, mock_external_service_dependencies
)
# Create tag
tag = self._create_test_tags(
db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1
)[0]
# Create tags
tags = self._create_test_tags(
db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2
)
# Create dataset and bind tag
# Create dataset and bind tags
dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id)
self._create_test_tag_bindings(
db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id
db_session_with_containers, mock_external_service_dependencies, tags, dataset.id, tenant.id
)
# Verify binding exists before deletion
binding_before = (
# Verify bindings exist before deletion
bindings_before = (
db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
.where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id)
.all()
)
assert binding_before is not None
assert len(bindings_before) == 2
# Act: Execute the method under test
delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id)
delete_payload = TagBindingDeletePayload(
type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags]
)
TagService.delete_tag_binding(delete_payload)
# Assert: Verify the expected outcomes
# Verify tag binding was deleted
binding_after = (
# Verify tag bindings were deleted
bindings_after = (
db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
.where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id)
.all()
)
assert binding_after is None
assert len(bindings_after) == 0
def test_delete_tag_binding_non_existent_binding(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -1156,7 +1157,7 @@ class TestTagService:
app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id)
# Act: Try to delete non-existent binding
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id)
delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id])
TagService.delete_tag_binding(delete_payload)
# Assert: Verify the expected outcomes

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.datastructures import FileStorage
from models.enums import AppTriggerStatus, AppTriggerType
@ -52,7 +53,7 @@ class TestWebhookService:
}
@pytest.fixture
def test_data(self, db_session_with_containers, mock_external_dependencies):
def test_data(self, db_session_with_containers: Session, mock_external_dependencies):
"""Create test data for webhook service tests."""
fake = Faker()
@ -160,7 +161,7 @@ class TestWebhookService:
"app_trigger": app_trigger,
}
def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers):
def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers: Flask):
"""Test successful retrieval of webhook trigger and workflow."""
webhook_id = test_data["webhook_id"]
@ -175,7 +176,7 @@ class TestWebhookService:
assert node_config["id"] == "webhook_node"
assert node_config["data"].title == "Test Webhook"
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers: Flask):
"""Test webhook trigger not found scenario."""
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Webhook not found"):
@ -421,7 +422,9 @@ class TestWebhookService:
assert result["files"] == {}
def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers):
def test_trigger_workflow_execution_success(
self, test_data, mock_external_dependencies, flask_app_with_containers: Flask
):
"""Test successful workflow execution trigger."""
webhook_data = {
"method": "POST",
@ -452,7 +455,7 @@ class TestWebhookService:
mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once()
def test_trigger_workflow_execution_end_user_service_failure(
self, test_data, mock_external_dependencies, flask_app_with_containers
self, test_data, mock_external_dependencies, flask_app_with_containers: Flask
):
"""Test workflow execution trigger when EndUserService fails."""
webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}}

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -165,7 +166,7 @@ class WebhookServiceRelationshipFactory:
class TestWebhookServiceLookupWithContainers:
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -182,7 +183,7 @@ class TestWebhookServiceLookupWithContainers:
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -202,7 +203,7 @@ class TestWebhookServiceLookupWithContainers:
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -222,7 +223,7 @@ class TestWebhookServiceLookupWithContainers:
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -239,7 +240,7 @@ class TestWebhookServiceLookupWithContainers:
WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id)
def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -275,7 +276,7 @@ class TestWebhookServiceLookupWithContainers:
class TestWebhookServiceTriggerExecutionWithContainers:
def test_trigger_workflow_execution_triggers_async_workflow_successfully(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -318,7 +319,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
assert trigger_args[2].root_node_id == webhook_trigger.node_id
def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -354,7 +355,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
mock_mark_rate_limited.assert_called_once_with(tenant.id)
def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -386,7 +387,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
class TestWebhookServiceRelationshipSyncWithContainers:
def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -401,7 +402,7 @@ class TestWebhookServiceRelationshipSyncWithContainers:
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_raises_when_lock_not_acquired(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -418,7 +419,7 @@ class TestWebhookServiceRelationshipSyncWithContainers:
WebhookService.sync_webhook_relationships(app, workflow)
def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -455,7 +456,7 @@ class TestWebhookServiceRelationshipSyncWithContainers:
assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None
def test_sync_webhook_relationships_sets_redis_cache_for_new_record(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -481,7 +482,7 @@ class TestWebhookServiceRelationshipSyncWithContainers:
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
def test_sync_webhook_relationships_logs_when_lock_release_fails(
self, db_session_with_containers: Session, flask_app_with_containers
self, db_session_with_containers: Session, flask_app_with_containers: Flask
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory

View File

@ -1530,7 +1530,7 @@ class TestWorkflowAppService:
assert result_cross_tenant["total"] == 0
def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
@ -1543,7 +1543,7 @@ class TestWorkflowAppService:
)
def test_get_paginate_workflow_app_logs_filters_by_account(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()
@ -1558,7 +1558,9 @@ class TestWorkflowAppService:
assert result["total"] >= 0
assert isinstance(result["data"], list)
def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_paginate_workflow_archive_logs(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
service = WorkflowAppService()

View File

@ -45,7 +45,9 @@ class TestWorkflowDraftVariableService:
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
return {}
def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None):
def _create_test_app(
self, db_session_with_containers: Session, mock_external_service_dependencies, fake: Faker | None = None
):
"""
Helper method to create a test app with realistic data for testing.
@ -80,7 +82,7 @@ class TestWorkflowDraftVariableService:
db_session_with_containers.commit()
return app
def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None):
def _create_test_workflow(self, db_session_with_containers: Session, app, fake: Faker | None = None):
"""
Helper method to create a test workflow associated with an app.

View File

@ -12,7 +12,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, App, Workflow
from models import Account, AccountStatus, App, TenantStatus, Workflow
from models.model import AppMode
from models.workflow import WorkflowType
from services.workflow_service import WorkflowService
@ -33,7 +33,7 @@ class TestWorkflowService:
and realistic testing environment with actual database interactions.
"""
def _create_test_account(self, db_session_with_containers: Session, fake=None):
def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test account with realistic data.
@ -49,7 +49,7 @@ class TestWorkflowService:
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
status=AccountStatus.ACTIVE,
interface_language="en-US", # Set interface language for Site creation
)
account.created_at = fake.date_time_this_year()
@ -62,7 +62,7 @@ class TestWorkflowService:
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="normal",
status=TenantStatus.NORMAL,
)
tenant.id = account.current_tenant_id
tenant.created_at = fake.date_time_this_year()
@ -77,7 +77,7 @@ class TestWorkflowService:
return account
def _create_test_app(self, db_session_with_containers: Session, fake=None):
def _create_test_app(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test app with realistic data.
@ -109,7 +109,7 @@ class TestWorkflowService:
db_session_with_containers.commit()
return app
def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None):
def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake: Faker | None = None):
"""
Helper method to create a test workflow associated with an app.

View File

@ -100,7 +100,7 @@ class TestWorkflowDeletion:
session.flush()
return provider
def test_delete_workflow_success(self, db_session_with_containers):
def test_delete_workflow_success(self, db_session_with_containers: Session):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
@ -118,7 +118,7 @@ class TestWorkflowDeletion:
db_session_with_containers.expire_all()
assert db_session_with_containers.get(Workflow, workflow_id) is None
def test_delete_draft_workflow_raises_error(self, db_session_with_containers):
def test_delete_draft_workflow_raises_error(self, db_session_with_containers: Session):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
@ -130,7 +130,7 @@ class TestWorkflowDeletion:
with pytest.raises(DraftWorkflowDeletionError):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers):
def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers: Session):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(
@ -144,7 +144,7 @@ class TestWorkflowDeletion:
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id)
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers):
def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers: Session):
tenant, account = self._create_tenant_and_account(db_session_with_containers)
app = self._create_app(db_session_with_containers, tenant=tenant, account=account)
workflow = self._create_workflow(

View File

@ -64,7 +64,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
db_session_with_containers.commit()
return execution
def test_get_node_last_execution_found(self, db_session_with_containers):
def test_get_node_last_execution_found(self, db_session_with_containers: Session):
"""Test getting the last execution for a node when it exists."""
# Arrange
tenant_id = str(uuid4())
@ -110,7 +110,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert result.id == expected.id
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
def test_get_node_last_execution_not_found(self, db_session_with_containers):
def test_get_node_last_execution_not_found(self, db_session_with_containers: Session):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
tenant_id = str(uuid4())
@ -129,7 +129,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result is None
def test_get_executions_by_workflow_run_empty(self, db_session_with_containers):
def test_get_executions_by_workflow_run_empty(self, db_session_with_containers: Session):
"""Test getting executions for a workflow run when none exist."""
# Arrange
tenant_id = str(uuid4())
@ -147,7 +147,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == []
def test_get_execution_by_id_found(self, db_session_with_containers):
def test_get_execution_by_id_found(self, db_session_with_containers: Session):
"""Test getting execution by ID when it exists."""
# Arrange
execution = self._create_execution(
@ -170,7 +170,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert result is not None
assert result.id == execution.id
def test_get_execution_by_id_not_found(self, db_session_with_containers):
def test_get_execution_by_id_not_found(self, db_session_with_containers: Session):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
repository = self._create_repository(db_session_with_containers)
@ -182,7 +182,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result is None
def test_delete_expired_executions(self, db_session_with_containers):
def test_delete_expired_executions(self, db_session_with_containers: Session):
"""Test deleting expired executions."""
# Arrange
tenant_id = str(uuid4())
@ -248,7 +248,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert old_execution_2_id not in remaining_ids
assert kept_execution_id in remaining_ids
def test_delete_executions_by_app(self, db_session_with_containers):
def test_delete_executions_by_app(self, db_session_with_containers: Session):
"""Test deleting executions by app."""
# Arrange
tenant_id = str(uuid4())
@ -313,7 +313,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert deleted_2_id not in remaining_ids
assert kept_id in remaining_ids
def test_get_expired_executions_batch(self, db_session_with_containers):
def test_get_expired_executions_batch(self, db_session_with_containers: Session):
"""Test getting expired executions batch for backup."""
# Arrange
tenant_id = str(uuid4())
@ -370,7 +370,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert old_execution_1.id in result_ids
assert old_execution_2.id in result_ids
def test_delete_executions_by_ids(self, db_session_with_containers):
def test_delete_executions_by_ids(self, db_session_with_containers: Session):
"""Test deleting executions by IDs."""
# Arrange
tenant_id = str(uuid4())
@ -424,7 +424,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
).all()
assert remaining == []
def test_delete_executions_by_ids_empty_list(self, db_session_with_containers):
def test_delete_executions_by_ids_empty_list(self, db_session_with_containers: Session):
"""Test deleting executions with empty ID list."""
# Arrange
repository = self._create_repository(db_session_with_containers)

View File

@ -71,7 +71,7 @@ class TestCleanNotionDocumentTask:
yield mock_factory
def test_clean_notion_document_task_success(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test successful cleanup of Notion documents with proper database operations.
@ -176,7 +176,7 @@ class TestCleanNotionDocumentTask:
# 5. The task completes without errors
def test_clean_notion_document_task_dataset_not_found(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task behavior when dataset is not found.
@ -196,7 +196,7 @@ class TestCleanNotionDocumentTask:
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task behavior with empty document list.
@ -240,7 +240,7 @@ class TestCleanNotionDocumentTask:
assert args[1] == []
def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with different dataset index types.
@ -328,7 +328,7 @@ class TestCleanNotionDocumentTask:
mock_index_processor_factory.reset_mock()
def test_clean_notion_document_task_with_segments_no_index_node_ids(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with segments that have no index_node_ids.
@ -411,7 +411,7 @@ class TestCleanNotionDocumentTask:
# are properly deleted from the database.
def test_clean_notion_document_task_partial_document_cleanup(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with partial document cleanup scenario.
@ -513,7 +513,7 @@ class TestCleanNotionDocumentTask:
# The database operations work correctly, isolating only the specified documents.
def test_clean_notion_document_task_with_mixed_segment_statuses(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with segments in different statuses.
@ -603,7 +603,7 @@ class TestCleanNotionDocumentTask:
# IndexProcessor verification would require more sophisticated mocking.
def test_clean_notion_document_task_continues_when_index_processor_fails(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Index processor failure (e.g. transient billing API error propagated via
@ -707,7 +707,7 @@ class TestCleanNotionDocumentTask:
assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0
def test_clean_notion_document_task_with_large_number_of_documents(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with a large number of documents and segments.
@ -806,7 +806,7 @@ class TestCleanNotionDocumentTask:
# The database efficiently handles large-scale deletions.
def test_clean_notion_document_task_with_documents_from_different_tenants(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with documents from different tenants.
@ -918,7 +918,7 @@ class TestCleanNotionDocumentTask:
# Only documents from the target dataset are affected, maintaining tenant separation.
def test_clean_notion_document_task_with_documents_in_different_states(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with documents in different indexing states.
@ -1024,7 +1024,7 @@ class TestCleanNotionDocumentTask:
# All documents are deleted regardless of their indexing status.
def test_clean_notion_document_task_with_documents_having_metadata(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies
):
"""
Test cleanup task with documents that have rich metadata.

View File

@ -12,6 +12,7 @@ from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy import delete
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from extensions.ext_redis import redis_client
@ -25,7 +26,7 @@ class TestCreateSegmentToIndexTask:
"""Integration tests for create_segment_to_index_task using testcontainers."""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database and Redis before each test to ensure isolation."""
# Clear all test data using fixture session
@ -55,7 +56,7 @@ class TestCreateSegmentToIndexTask:
"index_processor": mock_processor,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
"""
Helper method to create a test account and tenant for testing.
@ -102,7 +103,7 @@ class TestCreateSegmentToIndexTask:
return account, tenant
def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id):
def _create_test_dataset_and_document(self, db_session_with_containers: Session, tenant_id, account_id):
"""
Helper method to create a test dataset and document for testing.
@ -151,7 +152,13 @@ class TestCreateSegmentToIndexTask:
return dataset, document
def _create_test_segment(
self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING
self,
db_session_with_containers: Session,
dataset_id,
document_id,
tenant_id,
account_id,
status=SegmentStatus.WAITING,
):
"""
Helper method to create a test document segment for testing.
@ -189,7 +196,9 @@ class TestCreateSegmentToIndexTask:
return segment
def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_segment_to_index_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful creation of segment to index.
@ -225,7 +234,7 @@ class TestCreateSegmentToIndexTask:
assert redis_client.exists(cache_key) == 0
def test_create_segment_to_index_segment_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent segment ID.
@ -246,7 +255,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_invalid_status(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment with invalid status.
@ -277,7 +286,9 @@ class TestCreateSegmentToIndexTask:
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_segment_to_index_no_dataset(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment without associated dataset.
@ -330,7 +341,9 @@ class TestCreateSegmentToIndexTask:
# Verify no index processor calls were made
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_segment_to_index_no_document(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment without associated document.
@ -367,7 +380,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_document_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment with disabled document.
@ -403,7 +416,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_document_archived(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment with archived document.
@ -439,7 +452,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_document_indexing_incomplete(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of segment with document that has incomplete indexing.
@ -475,7 +488,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_not_called()
def test_create_segment_to_index_processor_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of index processor exceptions.
@ -511,7 +524,7 @@ class TestCreateSegmentToIndexTask:
assert redis_client.exists(cache_key) == 0
def test_create_segment_to_index_with_keywords(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with custom keywords.
@ -543,7 +556,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
def test_create_segment_to_index_different_doc_forms(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with different document forms.
@ -586,7 +599,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form)
def test_create_segment_to_index_performance_timing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing performance and timing.
@ -617,7 +630,7 @@ class TestCreateSegmentToIndexTask:
assert segment.status == SegmentStatus.COMPLETED
def test_create_segment_to_index_concurrent_execution(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test concurrent execution of segment indexing tasks.
@ -654,7 +667,7 @@ class TestCreateSegmentToIndexTask:
assert mock_external_service_dependencies["index_processor_factory"].call_count == 3
def test_create_segment_to_index_large_content(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with large content.
@ -703,7 +716,7 @@ class TestCreateSegmentToIndexTask:
assert segment.completed_at is not None
def test_create_segment_to_index_redis_failure(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing when Redis operations fail.
@ -743,7 +756,7 @@ class TestCreateSegmentToIndexTask:
assert redis_client.exists(cache_key) == 1
def test_create_segment_to_index_database_transaction_rollback(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with database transaction handling.
@ -775,7 +788,7 @@ class TestCreateSegmentToIndexTask:
assert segment.error is not None
def test_create_segment_to_index_metadata_validation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with metadata validation.
@ -817,7 +830,7 @@ class TestCreateSegmentToIndexTask:
assert doc is not None
def test_create_segment_to_index_status_transition_flow(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test complete status transition flow during indexing.
@ -852,7 +865,7 @@ class TestCreateSegmentToIndexTask:
assert segment.indexing_at <= segment.completed_at
def test_create_segment_to_index_with_empty_content(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with empty or minimal content.
@ -894,7 +907,7 @@ class TestCreateSegmentToIndexTask:
assert segment.completed_at is not None
def test_create_segment_to_index_with_special_characters(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with special characters and unicode content.
@ -940,7 +953,7 @@ class TestCreateSegmentToIndexTask:
assert segment.completed_at is not None
def test_create_segment_to_index_with_long_keywords(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with long keyword lists.
@ -974,7 +987,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
def test_create_segment_to_index_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with proper tenant isolation.
@ -1017,7 +1030,7 @@ class TestCreateSegmentToIndexTask:
assert segment1.tenant_id != segment2.tenant_id
def test_create_segment_to_index_with_none_keywords(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment indexing with None keywords parameter.
@ -1048,7 +1061,7 @@ class TestCreateSegmentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
def test_create_segment_to_index_comprehensive_integration(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Comprehensive integration test covering multiple scenarios.

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -174,11 +175,11 @@ class TestDatasetIndexingTaskIntegration:
return dataset, documents
def _query_document(self, db_session_with_containers, document_id: str) -> Document | None:
def _query_document(self, db_session_with_containers: Session, document_id: str) -> Document | None:
"""Return the latest persisted document state."""
return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1))
def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None:
def _assert_documents_parsing(self, db_session_with_containers: Session, document_ids: Sequence[str]) -> None:
"""Assert all target documents are persisted in parsing status."""
db_session_with_containers.expire_all()
for document_id in document_ids:
@ -212,7 +213,9 @@ class TestDatasetIndexingTaskIntegration:
assert len(opened) >= 2
assert opened_ids <= closed_ids
def test_legacy_document_indexing_task_still_works(self, db_session_with_containers, patched_external_dependencies):
def test_legacy_document_indexing_task_still_works(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Ensure the legacy task entrypoint still updates parsing status."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2)
@ -225,7 +228,9 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_called_once()
self._assert_documents_parsing(db_session_with_containers, document_ids)
def test_batch_processing_multiple_documents(self, db_session_with_containers, patched_external_dependencies):
def test_batch_processing_multiple_documents(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Process multiple documents in one batch."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3)
@ -240,7 +245,9 @@ class TestDatasetIndexingTaskIntegration:
assert len(run_args) == len(document_ids)
self._assert_documents_parsing(db_session_with_containers, document_ids)
def test_batch_processing_with_limit_check(self, db_session_with_containers, patched_external_dependencies):
def test_batch_processing_with_limit_check(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Reject batches larger than configured upload limit.
This test patches config only to force a deterministic limit branch while keeping SQL writes real.
@ -263,7 +270,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_error_contains(db_session_with_containers, document_ids, "batch upload limit")
def test_batch_processing_sandbox_plan_single_document_only(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Reject multi-document upload under sandbox plan."""
# Arrange
@ -280,7 +287,9 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_not_called()
self._assert_documents_error_contains(db_session_with_containers, document_ids, "does not support batch upload")
def test_batch_processing_empty_document_list(self, db_session_with_containers, patched_external_dependencies):
def test_batch_processing_empty_document_list(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Handle empty list input without failing."""
# Arrange
dataset, _ = self._create_test_dataset_and_documents(db_session_with_containers, document_count=0)
@ -292,7 +301,7 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
def test_tenant_queue_dispatches_next_task_after_completion(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Dispatch the next queued task after current tenant task completes.
@ -337,7 +346,7 @@ class TestDatasetIndexingTaskIntegration:
delete_key_spy.assert_not_called()
def test_tenant_queue_deletes_running_key_when_no_follow_up_tasks(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Delete tenant running flag when queue has no pending tasks.
@ -362,7 +371,7 @@ class TestDatasetIndexingTaskIntegration:
delete_key_spy.assert_called_once()
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Set error status when vector space validation fails before runner phase."""
# Arrange
@ -382,7 +391,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_error_contains(db_session_with_containers, document_ids, "over the limit")
def test_runner_exception_does_not_crash_indexing_task(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Catch generic runner exceptions without crashing the task."""
# Arrange
@ -397,7 +406,7 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_called_once()
self._assert_documents_parsing(db_session_with_containers, document_ids)
def test_document_paused_error_handling(self, db_session_with_containers, patched_external_dependencies):
def test_document_paused_error_handling(self, db_session_with_containers: Session, patched_external_dependencies):
"""Handle DocumentIsPausedError and keep persisted state consistent."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2)
@ -424,7 +433,7 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_tenant_queue_error_handling_still_processes_next_task(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Even on current task failure, enqueue the next waiting tenant task.
@ -491,7 +500,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_all_opened_sessions_closed(session_close_tracker)
def test_multiple_documents_with_mixed_success_and_failure(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Process only existing documents when request includes missing ids."""
# Arrange
@ -508,7 +517,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_parsing(db_session_with_containers, existing_ids)
def test_tenant_queue_dispatches_up_to_concurrency_limit(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Dispatch only up to configured concurrency under queued backlog burst.
@ -543,7 +552,7 @@ class TestDatasetIndexingTaskIntegration:
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
assert set_waiting_spy.call_count == concurrency_limit
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
def test_task_queue_fifo_ordering(self, db_session_with_containers: Session, patched_external_dependencies):
"""Keep FIFO ordering when dispatching next queued tasks.
Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@ -576,7 +585,9 @@ class TestDatasetIndexingTaskIntegration:
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
def test_billing_disabled_skips_limit_checks(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Skip limit checks when billing feature is disabled."""
# Arrange
large_document_ids = [str(uuid.uuid4()) for _ in range(100)]
@ -595,7 +606,7 @@ class TestDatasetIndexingTaskIntegration:
assert len(run_args) == 100
self._assert_documents_parsing(db_session_with_containers, large_document_ids)
def test_complete_workflow_normal_task(self, db_session_with_containers, patched_external_dependencies):
def test_complete_workflow_normal_task(self, db_session_with_containers: Session, patched_external_dependencies):
"""Run end-to-end normal queue workflow with tenant queue cleanup.
Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@ -618,7 +629,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_parsing(db_session_with_containers, document_ids)
delete_key_spy.assert_called_once()
def test_complete_workflow_priority_task(self, db_session_with_containers, patched_external_dependencies):
def test_complete_workflow_priority_task(self, db_session_with_containers: Session, patched_external_dependencies):
"""Run end-to-end priority queue workflow with tenant queue cleanup.
Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@ -641,7 +652,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_parsing(db_session_with_containers, document_ids)
delete_key_spy.assert_called_once()
def test_single_document_processing(self, db_session_with_containers, patched_external_dependencies):
def test_single_document_processing(self, db_session_with_containers: Session, patched_external_dependencies):
"""Process the minimum batch size (single document)."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=1)
@ -655,7 +666,9 @@ class TestDatasetIndexingTaskIntegration:
assert len(run_args) == 1
self._assert_documents_parsing(db_session_with_containers, [document_id])
def test_document_with_special_characters_in_id(self, db_session_with_containers, patched_external_dependencies):
def test_document_with_special_characters_in_id(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Handle standard UUID ids with hyphen characters safely."""
# Arrange
special_document_id = str(uuid.uuid4())
@ -670,7 +683,9 @@ class TestDatasetIndexingTaskIntegration:
# Assert
self._assert_documents_parsing(db_session_with_containers, [special_document_id])
def test_zero_vector_space_limit_allows_unlimited(self, db_session_with_containers, patched_external_dependencies):
def test_zero_vector_space_limit_allows_unlimited(
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Treat vector limit 0 as unlimited and continue indexing."""
# Arrange
dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3)
@ -689,7 +704,7 @@ class TestDatasetIndexingTaskIntegration:
self._assert_documents_parsing(db_session_with_containers, document_ids)
def test_negative_vector_space_values_handled_gracefully(
self, db_session_with_containers, patched_external_dependencies
self, db_session_with_containers: Session, patched_external_dependencies
):
"""Treat negative vector limits as non-blocking and continue indexing."""
# Arrange
@ -708,7 +723,7 @@ class TestDatasetIndexingTaskIntegration:
patched_external_dependencies["indexing_runner_instance"].run.assert_called_once()
self._assert_documents_parsing(db_session_with_containers, document_ids)
def test_large_document_batch_processing(self, db_session_with_containers, patched_external_dependencies):
def test_large_document_batch_processing(self, db_session_with_containers: Session, patched_external_dependencies):
"""Process a batch exactly at configured upload limit.
This test patches config only to force a deterministic limit branch while keeping SQL writes real.

View File

@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from models.dataset import Dataset, Document, DocumentSegment
@ -55,7 +56,7 @@ class TestDealDatasetVectorIndexTask:
yield mock_factory
@pytest.fixture
def account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Create an account with an owner tenant for testing.
Returns a tuple of (account, tenant) where tenant is guaranteed to be non-None.
@ -73,7 +74,7 @@ class TestDealDatasetVectorIndexTask:
return account, tenant
def test_deal_dataset_vector_index_task_remove_action_success(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test successful removal of dataset vector index.
@ -131,7 +132,7 @@ class TestDealDatasetVectorIndexTask:
assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail
def test_deal_dataset_vector_index_task_add_action_success(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test successful addition of dataset vector index.
@ -233,7 +234,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_update_action_success(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test successful update of dataset vector index.
@ -337,7 +338,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_dataset_not_found_error(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior when dataset is not found.
@ -357,7 +358,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_not_called()
def test_deal_dataset_vector_index_task_add_action_no_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test add action when no documents exist for the dataset.
@ -389,7 +390,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_not_called()
def test_deal_dataset_vector_index_task_add_action_no_segments(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test add action when documents exist but have no segments.
@ -447,7 +448,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_not_called()
def test_deal_dataset_vector_index_task_update_action_no_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test update action when no documents exist for the dataset.
@ -480,7 +481,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_not_called()
def test_deal_dataset_vector_index_task_add_action_with_exception_handling(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test add action with exception handling during processing.
@ -578,7 +579,7 @@ class TestDealDatasetVectorIndexTask:
assert "Test exception during indexing" in updated_document.error
def test_deal_dataset_vector_index_task_with_custom_index_type(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior with custom index type (QA_INDEX).
@ -656,7 +657,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_with_default_index_type(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior with default index type (PARAGRAPH_INDEX).
@ -734,7 +735,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_multiple_documents_processing(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task processing with multiple documents and segments.
@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask:
assert mock_processor.load.call_count == 3
def test_deal_dataset_vector_index_task_document_status_transitions(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test document status transitions during task execution.
@ -938,7 +939,7 @@ class TestDealDatasetVectorIndexTask:
assert updated_document.indexing_status == IndexingStatus.COMPLETED
def test_deal_dataset_vector_index_task_with_disabled_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior with disabled documents.
@ -1061,7 +1062,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_with_archived_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior with archived documents.
@ -1184,7 +1185,7 @@ class TestDealDatasetVectorIndexTask:
mock_processor.load.assert_called_once()
def test_deal_dataset_vector_index_task_with_incomplete_documents(
self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
):
"""
Test task behavior with documents that have incomplete indexing status.

View File

@ -11,9 +11,19 @@ import logging
from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, Document, DocumentSegment, Tenant
from models import (
Account,
AccountStatus,
Dataset,
DatasetPermissionEnum,
Document,
DocumentSegment,
Tenant,
TenantStatus,
)
from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@ -37,7 +47,7 @@ class TestDeleteSegmentFromIndexTask:
and realistic testing environment with actual database interactions.
"""
def _create_test_tenant(self, db_session_with_containers, fake=None):
def _create_test_tenant(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test tenant with realistic data.
@ -49,7 +59,7 @@ class TestDeleteSegmentFromIndexTask:
Tenant: Created test tenant instance
"""
fake = fake or Faker()
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal")
tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status=TenantStatus.NORMAL)
tenant.id = fake.uuid4()
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
@ -58,7 +68,7 @@ class TestDeleteSegmentFromIndexTask:
db_session_with_containers.commit()
return tenant
def _create_test_account(self, db_session_with_containers, tenant, fake=None):
def _create_test_account(self, db_session_with_containers: Session, tenant, fake: Faker | None = None):
"""
Helper method to create a test account with realistic data.
@ -75,7 +85,7 @@ class TestDeleteSegmentFromIndexTask:
name=fake.name(),
email=fake.email(),
avatar=fake.url(),
status="active",
status=AccountStatus.ACTIVE,
interface_language="en-US",
)
account.id = fake.uuid4()
@ -86,7 +96,9 @@ class TestDeleteSegmentFromIndexTask:
db_session_with_containers.commit()
return account
def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None):
def _create_test_dataset(
self, db_session_with_containers: Session, tenant: Tenant, account: Account, fake: Faker | None = None
):
"""
Helper method to create a test dataset with realistic data.
@ -106,7 +118,7 @@ class TestDeleteSegmentFromIndexTask:
dataset.name = f"Test Dataset {fake.word()}"
dataset.description = fake.text(max_nb_chars=200)
dataset.provider = "vendor"
dataset.permission = "only_me"
dataset.permission = DatasetPermissionEnum.ONLY_ME
dataset.data_source_type = DataSourceType.UPLOAD_FILE
dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
dataset.index_struct = '{"type": "paragraph"}'
@ -122,7 +134,7 @@ class TestDeleteSegmentFromIndexTask:
db_session_with_containers.commit()
return dataset
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs):
def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None, **kwargs):
"""
Helper method to create a test document with realistic data.
@ -172,7 +184,14 @@ class TestDeleteSegmentFromIndexTask:
db_session_with_containers.commit()
return document
def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None):
def _create_test_document_segments(
self,
db_session_with_containers: Session,
document: Document,
account: Account,
count: int = 3,
fake: Faker | None = None,
):
"""
Helper method to create test document segments with realistic data.
@ -218,7 +237,9 @@ class TestDeleteSegmentFromIndexTask:
return segments
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
def test_delete_segment_from_index_task_success(
self, mock_index_processor_factory, db_session_with_containers: Session
):
"""
Test successful segment deletion from index with comprehensive verification.
@ -267,7 +288,7 @@ class TestDeleteSegmentFromIndexTask:
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is True
def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers):
def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers: Session):
"""
Test task behavior when dataset is not found.
@ -288,7 +309,7 @@ class TestDeleteSegmentFromIndexTask:
# Verify the task completed without exceptions
assert result is None # Task should return None when dataset not found
def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers):
def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers: Session):
"""
Test task behavior when document is not found.
@ -314,7 +335,7 @@ class TestDeleteSegmentFromIndexTask:
# Verify the task completed without exceptions
assert result is None # Task should return None when document not found
def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers):
def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers: Session):
"""
Test task behavior when document is disabled.
@ -342,7 +363,7 @@ class TestDeleteSegmentFromIndexTask:
# Verify the task completed without exceptions
assert result is None # Task should return None when document is disabled
def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers):
def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers: Session):
"""
Test task behavior when document is archived.
@ -370,7 +391,7 @@ class TestDeleteSegmentFromIndexTask:
# Verify the task completed without exceptions
assert result is None # Task should return None when document is archived
def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers):
def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers: Session):
"""
Test task behavior when document indexing is not completed.

View File

@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Dataset, DocumentSegment
from models import Account, AccountStatus, Dataset, DocumentSegment, TenantAccountRole, TenantStatus
from models import Document as DatasetDocument
from models.dataset import DatasetProcessRule
from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus
@ -35,7 +35,7 @@ class TestDisableSegmentsFromIndexTask:
and realistic testing environment with actual database interactions.
"""
def _create_test_account(self, db_session_with_containers: Session, fake=None):
def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test account with realistic data.
@ -51,24 +51,23 @@ class TestDisableSegmentsFromIndexTask:
email=fake.email(),
name=fake.name(),
avatar=fake.url(),
status="active",
status=AccountStatus.ACTIVE,
interface_language="en-US",
)
account.id = fake.uuid4()
# monkey-patch attributes for test setup
account.updated_at = fake.date_time_this_year()
account.created_at = fake.date_time_this_year()
account.role = TenantAccountRole.OWNER
account.id = fake.uuid4()
account.tenant_id = fake.uuid4()
account.type = "normal"
account.role = "owner"
account.created_at = fake.date_time_this_year()
account.updated_at = account.created_at
# Create a tenant for the account
from models.account import Tenant
tenant = Tenant(
name=f"Test Tenant {fake.company()}",
plan="basic",
status="normal",
status=TenantStatus.NORMAL,
)
tenant.id = account.tenant_id
tenant.created_at = fake.date_time_this_year()
@ -83,7 +82,7 @@ class TestDisableSegmentsFromIndexTask:
return account
def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None):
def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None):
"""
Helper method to create a test dataset with realistic data.
@ -117,7 +116,9 @@ class TestDisableSegmentsFromIndexTask:
return dataset
def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None):
def _create_test_document(
self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None
):
"""
Helper method to create a test document with realistic data.
@ -216,7 +217,7 @@ class TestDisableSegmentsFromIndexTask:
return segments
def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None):
def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None):
"""
Helper method to create a dataset process rule.

View File

@ -13,6 +13,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import Session
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -162,7 +163,7 @@ class TestDocumentIndexingSyncTask:
"indexing_runner": indexing_runner,
}
def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None):
def _create_notion_sync_context(self, db_session_with_containers: Session, *, data_source_info: dict | None = None):
account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset(
db_session_with_containers,
@ -206,7 +207,7 @@ class TestDocumentIndexingSyncTask:
"notion_info": notion_info,
}
def test_document_not_found(self, db_session_with_containers, mock_external_dependencies):
def test_document_not_found(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task handles missing document gracefully."""
# Arrange
dataset_id = str(uuid4())
@ -219,7 +220,7 @@ class TestDocumentIndexingSyncTask:
mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies):
def test_missing_notion_workspace_id(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task raises error when notion_workspace_id is missing."""
# Arrange
context = self._create_notion_sync_context(
@ -235,7 +236,7 @@ class TestDocumentIndexingSyncTask:
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies):
def test_missing_notion_page_id(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task raises error when notion_page_id is missing."""
# Arrange
context = self._create_notion_sync_context(
@ -251,7 +252,7 @@ class TestDocumentIndexingSyncTask:
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies):
def test_empty_data_source_info(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
@ -264,7 +265,7 @@ class TestDocumentIndexingSyncTask:
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies):
def test_credential_not_found(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task sets document error state when credential is missing."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -284,7 +285,7 @@ class TestDocumentIndexingSyncTask:
assert updated_document.stopped_at is not None
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies):
def test_page_not_updated(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task exits early when notion page is unchanged."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -310,7 +311,7 @@ class TestDocumentIndexingSyncTask:
mock_external_dependencies["index_processor"].clean.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies):
def test_successful_sync_when_page_updated(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test full successful sync flow with SQL state updates and side effects."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -349,7 +350,7 @@ class TestDocumentIndexingSyncTask:
assert len(run_documents) == 1
assert getattr(run_documents[0], "id", None) == context["document"].id
def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies):
def test_dataset_not_found_during_cleaning(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that task still updates document and reindexes if dataset vanishes before clean."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -376,7 +377,9 @@ class TestDocumentIndexingSyncTask:
mock_external_dependencies["index_processor"].clean.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_called_once()
def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies):
def test_cleaning_error_continues_to_indexing(
self, db_session_with_containers: Session, mock_external_dependencies
):
"""Test that indexing continues when index cleanup fails."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -400,7 +403,9 @@ class TestDocumentIndexingSyncTask:
assert remaining_segments == 0
mock_external_dependencies["indexing_runner"].run.assert_called_once()
def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies):
def test_indexing_runner_document_paused_error(
self, db_session_with_containers: Session, mock_external_dependencies
):
"""Test that DocumentIsPausedError does not flip document into error state."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
@ -418,7 +423,7 @@ class TestDocumentIndexingSyncTask:
assert updated_document.indexing_status == IndexingStatus.PARSING
assert updated_document.error is None
def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies):
def test_indexing_runner_general_error(self, db_session_with_containers: Session, mock_external_dependencies):
"""Test that indexing errors are persisted to document state."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)

View File

@ -3,11 +3,12 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.entities.document_task import DocumentTask
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from enums.cloud_plan import CloudPlan
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, Document
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
from tasks.document_indexing_task import (
@ -51,7 +52,7 @@ class TestDocumentIndexingTasks:
}
def _create_test_dataset_and_documents(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3
):
"""
Helper method to create a test dataset and documents for testing.
@ -71,14 +72,14 @@ class TestDocumentIndexingTasks:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -133,7 +134,7 @@ class TestDocumentIndexingTasks:
return dataset, documents
def _create_test_dataset_with_billing_features(
self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True
):
"""
Helper method to create a test dataset with billing features configured.
@ -153,14 +154,14 @@ class TestDocumentIndexingTasks:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -221,7 +222,9 @@ class TestDocumentIndexingTasks:
return dataset, documents
def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_document_indexing_task_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful document indexing with multiple documents.
@ -262,7 +265,7 @@ class TestDocumentIndexingTasks:
assert len(processed_documents) == 3
def test_document_indexing_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
@ -286,7 +289,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_document_indexing_task_document_not_found_in_dataset(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling when some documents don't exist in the dataset.
@ -332,7 +335,7 @@ class TestDocumentIndexingTasks:
assert len(processed_documents) == 2 # Only existing documents
def test_document_indexing_task_indexing_runner_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of IndexingRunner exceptions.
@ -373,7 +376,7 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test processing documents with mixed initial states.
@ -456,7 +459,7 @@ class TestDocumentIndexingTasks:
assert len(processed_documents) == 4
def test_document_indexing_task_billing_sandbox_plan_batch_limit(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test billing validation for sandbox plan batch upload limit.
@ -518,7 +521,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner"].assert_not_called()
def test_document_indexing_task_billing_disabled_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful processing when billing is disabled.
@ -554,7 +557,7 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of DocumentIsPausedError from IndexingRunner.
@ -597,7 +600,9 @@ class TestDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_old_document_indexing_task_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test document_indexing_task basic functionality.
@ -619,7 +624,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_normal_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test normal_document_indexing_task basic functionality.
@ -643,7 +648,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_priority_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test priority_document_indexing_task basic functionality.
@ -667,7 +672,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_document_indexing_with_tenant_queue_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with no waiting tasks.
@ -717,7 +722,7 @@ class TestDocumentIndexingTasks:
mock_task_func.delay.assert_not_called()
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
@ -776,7 +781,7 @@ class TestDocumentIndexingTasks:
assert len(remaining_tasks) == 1
def test_document_indexing_with_tenant_queue_error_handling(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling in _document_indexing_with_tenant_queue using real Redis.
@ -848,7 +853,7 @@ class TestDocumentIndexingTasks:
assert len(remaining_tasks) == 0
def test_document_indexing_with_tenant_queue_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.

View File

@ -3,9 +3,10 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
from tasks.document_indexing_update_task import document_indexing_update_task
@ -33,7 +34,7 @@ class TestDocumentIndexingUpdateTask:
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
def _create_dataset_document_with_segments(self, db_session_with_containers: Session, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
@ -41,12 +42,12 @@ class TestDocumentIndexingUpdateTask:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -114,7 +115,7 @@ class TestDocumentIndexingUpdateTask:
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
def test_cleans_segments_and_reindexes(self, db_session_with_containers: Session, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
@ -153,7 +154,9 @@ class TestDocumentIndexingUpdateTask:
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
def test_clean_error_is_logged_and_indexing_continues(
self, db_session_with_containers: Session, mock_external_dependencies
):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
@ -173,7 +176,7 @@ class TestDocumentIndexingUpdateTask:
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
def test_document_not_found_noop(self, db_session_with_containers: Session, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -62,7 +63,7 @@ class TestDuplicateDocumentIndexingTasks:
}
def _create_test_dataset_and_documents(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3
):
"""
Helper method to create a test dataset and documents for testing.
@ -145,7 +146,11 @@ class TestDuplicateDocumentIndexingTasks:
return dataset, documents
def _create_test_dataset_with_segments(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2
self,
db_session_with_containers: Session,
mock_external_service_dependencies,
document_count=3,
segments_per_doc=2,
):
"""
Helper method to create a test dataset with documents and segments.
@ -197,7 +202,7 @@ class TestDuplicateDocumentIndexingTasks:
return dataset, documents, segments
def _create_test_dataset_with_billing_features(
self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True
):
"""
Helper method to create a test dataset with billing features configured.
@ -287,7 +292,7 @@ class TestDuplicateDocumentIndexingTasks:
return dataset, documents
def _test_duplicate_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful duplicate document indexing with multiple documents.
@ -329,7 +334,7 @@ class TestDuplicateDocumentIndexingTasks:
assert len(processed_documents) == 3
def _test_duplicate_document_indexing_task_with_segment_cleanup(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test duplicate document indexing with existing segments that need cleanup.
@ -379,7 +384,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def _test_duplicate_document_indexing_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
@ -404,7 +409,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["index_processor"].clean.assert_not_called()
def test_duplicate_document_indexing_task_document_not_found_in_dataset(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling when some documents don't exist in the dataset.
@ -450,7 +455,7 @@ class TestDuplicateDocumentIndexingTasks:
assert len(processed_documents) == 2 # Only existing documents
def _test_duplicate_document_indexing_task_indexing_runner_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of IndexingRunner exceptions.
@ -491,7 +496,7 @@ class TestDuplicateDocumentIndexingTasks:
assert updated_document.processing_started_at is not None
def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test billing validation for sandbox plan batch upload limit.
@ -554,7 +559,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test billing validation for vector space limit.
@ -596,7 +601,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_duplicate_document_indexing_task_with_empty_document_list(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of empty document list.
@ -622,7 +627,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
def test_deprecated_duplicate_document_indexing_task_delegates_to_core(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that deprecated duplicate_document_indexing_task delegates to core function.
@ -655,7 +660,7 @@ class TestDuplicateDocumentIndexingTasks:
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test normal_duplicate_document_indexing_task with tenant isolation queue.
@ -698,7 +703,7 @@ class TestDuplicateDocumentIndexingTasks:
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test priority_duplicate_document_indexing_task with tenant isolation queue.
@ -742,7 +747,7 @@ class TestDuplicateDocumentIndexingTasks:
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_tenant_queue_wrapper_processes_next_tasks(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant queue wrapper processes next queued tasks.
@ -789,7 +794,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_not_called()
def test_successful_duplicate_document_indexing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test successful duplicate document indexing flow."""
self._test_duplicate_document_indexing_task_success(
@ -797,7 +802,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def test_duplicate_document_indexing_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing when dataset is not found."""
self._test_duplicate_document_indexing_task_dataset_not_found(
@ -805,7 +810,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing with billing enabled and sandbox plan."""
self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
@ -813,7 +818,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def test_duplicate_document_indexing_with_billing_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing when billing limit is exceeded."""
self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
@ -821,7 +826,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def test_duplicate_document_indexing_runner_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
self._test_duplicate_document_indexing_task_indexing_runner_exception(
@ -829,7 +834,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def _test_duplicate_document_indexing_task_document_is_paused(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing when document is paused."""
# Arrange
@ -860,7 +865,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_duplicate_document_indexing_document_is_paused(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test duplicate document indexing when document is paused."""
self._test_duplicate_document_indexing_task_document_is_paused(
@ -868,7 +873,7 @@ class TestDuplicateDocumentIndexingTasks:
)
def test_duplicate_document_indexing_cleans_old_segments(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""Test that duplicate document indexing cleans old segments."""
self._test_duplicate_document_indexing_task_with_segment_cleanup(

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -29,7 +30,7 @@ class TestMailChangeMailTask:
"get_email_i18n_service": mock_get_email_i18n_service,
}
def _create_test_account(self, db_session_with_containers):
def _create_test_account(self, db_session_with_containers: Session):
"""
Helper method to create a test account for testing.
@ -72,7 +73,7 @@ class TestMailChangeMailTask:
return account
def test_send_change_mail_task_success_old_email_phase(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful change email task execution for old_email phase.
@ -103,7 +104,7 @@ class TestMailChangeMailTask:
)
def test_send_change_mail_task_success_new_email_phase(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful change email task execution for new_email phase.
@ -134,7 +135,7 @@ class TestMailChangeMailTask:
)
def test_send_change_mail_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test change email task when mail service is not initialized.
@ -159,7 +160,7 @@ class TestMailChangeMailTask:
mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called()
def test_send_change_mail_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test change email task when email service raises an exception.
@ -191,7 +192,7 @@ class TestMailChangeMailTask:
)
def test_send_change_mail_completed_notification_task_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful change email completed notification task execution.
@ -224,7 +225,7 @@ class TestMailChangeMailTask:
)
def test_send_change_mail_completed_notification_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test change email completed notification task when mail service is not initialized.
@ -247,7 +248,7 @@ class TestMailChangeMailTask:
mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called()
def test_send_change_mail_completed_notification_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test change email completed notification task when email service raises an exception.

View File

@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -37,7 +38,7 @@ class TestSendEmailCodeLoginMailTask:
"""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_redis import redis_client
@ -71,7 +72,7 @@ class TestSendEmailCodeLoginMailTask:
"email_service_instance": mock_email_service_instance,
}
def _create_test_account(self, db_session_with_containers, fake=None):
def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test account for testing.
@ -98,7 +99,7 @@ class TestSendEmailCodeLoginMailTask:
return account
def _create_test_tenant_and_account(self, db_session_with_containers, fake=None):
def _create_test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker | None = None):
"""
Helper method to create a test tenant and account for testing.
@ -138,7 +139,7 @@ class TestSendEmailCodeLoginMailTask:
return account, tenant
def test_send_email_code_login_mail_task_success_english(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful email code login mail sending in English.
@ -182,7 +183,7 @@ class TestSendEmailCodeLoginMailTask:
)
def test_send_email_code_login_mail_task_success_chinese(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful email code login mail sending in Chinese.
@ -221,7 +222,7 @@ class TestSendEmailCodeLoginMailTask:
)
def test_send_email_code_login_mail_task_success_multiple_languages(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful email code login mail sending with multiple languages.
@ -261,7 +262,7 @@ class TestSendEmailCodeLoginMailTask:
assert call_args[1]["template_context"]["code"] == test_codes[i]
def test_send_email_code_login_mail_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task when mail service is not initialized.
@ -299,7 +300,7 @@ class TestSendEmailCodeLoginMailTask:
mock_email_service_instance.send_email.assert_not_called()
def test_send_email_code_login_mail_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task when email service raises an exception.
@ -346,7 +347,7 @@ class TestSendEmailCodeLoginMailTask:
)
def test_send_email_code_login_mail_task_invalid_parameters(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task with invalid parameters.
@ -388,7 +389,7 @@ class TestSendEmailCodeLoginMailTask:
mock_email_service_instance.send_email.assert_called_once()
def test_send_email_code_login_mail_task_edge_cases(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task with edge cases and boundary conditions.
@ -451,7 +452,7 @@ class TestSendEmailCodeLoginMailTask:
)
def test_send_email_code_login_mail_task_database_integration(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task with database integration.
@ -497,7 +498,7 @@ class TestSendEmailCodeLoginMailTask:
assert account.status == "active"
def test_send_email_code_login_mail_task_redis_integration(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login mail task with Redis integration.
@ -541,7 +542,7 @@ class TestSendEmailCodeLoginMailTask:
redis_client.delete(cache_key)
def test_send_email_code_login_mail_task_error_handling_comprehensive(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test comprehensive error handling for email code login mail task.

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -172,7 +173,9 @@ def _create_workflow_pause_state(
db_session_with_containers.commit()
def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers):
def test_dispatch_human_input_email_task_integration(
monkeypatch: pytest.MonkeyPatch, db_session_with_containers: Session
):
tenant, account = _create_workspace_member(db_session_with_containers)
workflow_run_id = str(uuid.uuid4())
workflow_id = str(uuid.uuid4())

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from tasks.mail_inner_task import send_inner_email_task
@ -51,7 +52,7 @@ class TestMailInnerTask:
},
}
def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_inner_email_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful email sending with valid data.
@ -90,7 +91,9 @@ class TestMailInnerTask:
html_content="<html>Test email content</html>",
)
def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_inner_email_single_recipient(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending with single recipient.
@ -126,7 +129,9 @@ class TestMailInnerTask:
html_content="<html>Test email content</html>",
)
def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_inner_email_empty_substitutions(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending with empty substitutions.
@ -163,7 +168,7 @@ class TestMailInnerTask:
)
def test_send_inner_email_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending when mail service is not initialized.
@ -193,7 +198,7 @@ class TestMailInnerTask:
mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called()
def test_send_inner_email_template_rendering_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending when template rendering fails.
@ -222,7 +227,9 @@ class TestMailInnerTask:
# Verify no email service calls due to exception
mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called()
def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_inner_email_service_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending when email service fails.

View File

@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from libs.email_i18n import EmailType
@ -42,7 +43,7 @@ class TestMailInviteMemberTask:
"""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation."""
# Clear all test data
db_session_with_containers.execute(delete(TenantAccountJoin))
@ -78,7 +79,7 @@ class TestMailInviteMemberTask:
"config": mock_config,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
"""
Helper method to create a test account and tenant for testing.
@ -147,7 +148,7 @@ class TestMailInviteMemberTask:
redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours
return token
def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant):
def _create_pending_account_for_invitation(self, db_session_with_containers: Session, email, tenant):
"""
Helper method to create a pending account for invitation testing.
@ -185,7 +186,9 @@ class TestMailInviteMemberTask:
return account
def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_invite_member_mail_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful invitation email sending with all parameters.
@ -231,7 +234,7 @@ class TestMailInviteMemberTask:
assert template_context["url"] == f"https://console.dify.ai/activate?token={token}"
def test_send_invite_member_mail_different_languages(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test invitation email sending with different language codes.
@ -263,7 +266,7 @@ class TestMailInviteMemberTask:
assert call_args[1]["language_code"] == language
def test_send_invite_member_mail_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test behavior when mail service is not initialized.
@ -292,7 +295,7 @@ class TestMailInviteMemberTask:
mock_email_service.send_email.assert_not_called()
def test_send_invite_member_mail_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when email service raises an exception.
@ -322,7 +325,7 @@ class TestMailInviteMemberTask:
assert "Send invite member mail to %s failed" in error_call
def test_send_invite_member_mail_template_context_validation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test template context contains all required fields for email rendering.
@ -368,7 +371,7 @@ class TestMailInviteMemberTask:
assert template_context["url"] == f"https://console.dify.ai/activate?token={token}"
def test_send_invite_member_mail_integration_with_redis_token(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test integration with Redis token validation.
@ -407,7 +410,7 @@ class TestMailInviteMemberTask:
assert invitation_data["workspace_id"] == tenant.id
def test_send_invite_member_mail_with_special_characters(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email sending with special characters in names and workspace names.
@ -449,7 +452,7 @@ class TestMailInviteMemberTask:
assert template_context["workspace_name"] == workspace_name
def test_send_invite_member_mail_real_database_integration(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test real database integration with actual invitation flow.
@ -501,7 +504,7 @@ class TestMailInviteMemberTask:
assert tenant_join.role == TenantAccountRole.NORMAL
def test_send_invite_member_mail_token_lifecycle_management(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test token lifecycle management and validation.

View File

@ -11,6 +11,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -44,7 +45,7 @@ class TestMailOwnerTransferTask:
"get_email_service": mock_get_email_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
"""
Helper method to create test account and tenant for testing.
@ -86,7 +87,9 @@ class TestMailOwnerTransferTask:
return account, tenant
def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies):
def test_send_owner_transfer_confirm_task_success(
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test successful owner transfer confirmation email sending.
@ -127,7 +130,7 @@ class TestMailOwnerTransferTask:
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_owner_transfer_confirm_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test owner transfer confirmation email when mail service is not initialized.
@ -158,7 +161,7 @@ class TestMailOwnerTransferTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_owner_transfer_confirm_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test exception handling in owner transfer confirmation email.
@ -192,7 +195,7 @@ class TestMailOwnerTransferTask:
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_old_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test successful old owner transfer notification email sending.
@ -234,7 +237,7 @@ class TestMailOwnerTransferTask:
assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email
def test_send_old_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test old owner transfer notification email when mail service is not initialized.
@ -265,7 +268,7 @@ class TestMailOwnerTransferTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_old_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test exception handling in old owner transfer notification email.
@ -299,7 +302,7 @@ class TestMailOwnerTransferTask:
mock_mail_dependencies["email_service"].send_email.assert_called_once()
def test_send_new_owner_transfer_notify_email_task_success(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test successful new owner transfer notification email sending.
@ -338,7 +341,7 @@ class TestMailOwnerTransferTask:
assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace
def test_send_new_owner_transfer_notify_email_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test new owner transfer notification email when mail service is not initialized.
@ -367,7 +370,7 @@ class TestMailOwnerTransferTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_new_owner_transfer_notify_email_task_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""
Test exception handling in new owner transfer notification email.

View File

@ -9,6 +9,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from libs.email_i18n import EmailType
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
@ -35,7 +36,7 @@ class TestMailRegisterTask:
"get_email_service": mock_get_email_service,
}
def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies):
def test_send_email_register_mail_task_success(self, db_session_with_containers: Session, mock_mail_dependencies):
"""Test successful email registration mail sending."""
fake = Faker()
language = "en-US"
@ -56,7 +57,7 @@ class TestMailRegisterTask:
)
def test_send_email_register_mail_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""Test email registration task when mail service is not initialized."""
mock_mail_dependencies["mail"].is_inited.return_value = False
@ -66,7 +67,9 @@ class TestMailRegisterTask:
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies):
def test_send_email_register_mail_task_exception_handling(
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""Test email registration task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
@ -79,7 +82,7 @@ class TestMailRegisterTask:
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
def test_send_email_register_mail_task_when_account_exist_success(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""Test successful email registration mail sending when account exists."""
fake = Faker()
@ -105,7 +108,7 @@ class TestMailRegisterTask:
)
def test_send_email_register_mail_task_when_account_exist_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""Test account exist email task when mail service is not initialized."""
mock_mail_dependencies["mail"].is_inited.return_value = False
@ -118,7 +121,7 @@ class TestMailRegisterTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_when_account_exist_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
self, db_session_with_containers: Session, mock_mail_dependencies
):
"""Test account exist email task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")

View File

@ -4,12 +4,13 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from flask import Flask
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Pipeline
from models.workflow import Workflow
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
@ -69,14 +70,14 @@ class TestRagPipelineRunTasks:
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -725,7 +726,7 @@ class TestRagPipelineRunTasks:
assert queue1._task_key != queue2._task_key
def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask
):
"""
Test successful run_single_rag_pipeline_task execution.
@ -760,7 +761,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask
):
"""
Test run_single_rag_pipeline_task with invalid entity data.
@ -805,7 +806,7 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask
):
"""
Test run_single_rag_pipeline_task with non-existent database entities.

View File

@ -3,6 +3,7 @@ from unittest.mock import ANY, call, patch
import pytest
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from extensions.storage.storage_type import StorageType
@ -117,7 +118,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id:
class TestDeleteDraftVariablesBatch:
def test_delete_draft_variables_batch_success(self, db_session_with_containers):
def test_delete_draft_variables_batch_success(self, db_session_with_containers: Session):
"""Test successful deletion of draft variables in batches."""
_, app1 = _create_tenant_and_app(db_session_with_containers)
_, app2 = _create_tenant_and_app(db_session_with_containers)
@ -137,7 +138,7 @@ class TestDeleteDraftVariablesBatch:
assert app1_remaining_count == 0
assert app2_remaining_count == 100
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers: Session):
"""Test deletion when no draft variables exist for the app."""
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
@ -176,7 +177,7 @@ class TestDeleteDraftVariableOffloadData:
"""Test the Offload data cleanup functionality."""
@patch("extensions.ext_storage.storage")
def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers):
def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers: Session):
"""Test successful deletion of offload data."""
tenant, app = _create_tenant_and_app(db_session_with_containers)
offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3)

View File

@ -1,12 +1,14 @@
from pathlib import Path
import pytest
from extensions.storage.opendal_storage import OpenDALStorage
class TestOpenDALFsDefaultRoot:
"""Test that OpenDALStorage with scheme='fs' works correctly when no root is provided."""
def test_fs_without_root_uses_default(self, tmp_path, monkeypatch):
def test_fs_without_root_uses_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""When no root is specified, the default 'storage' should be used and passed to the Operator."""
# Change to tmp_path so the default "storage" dir is created there
monkeypatch.chdir(tmp_path)
@ -25,7 +27,7 @@ class TestOpenDALFsDefaultRoot:
# Cleanup
storage.delete("test_default_root.txt")
def test_fs_with_explicit_root(self, tmp_path):
def test_fs_with_explicit_root(self, tmp_path: Path):
"""When root is explicitly provided, it should be used."""
custom_root = str(tmp_path / "custom_storage")
storage = OpenDALStorage(scheme="fs", root=custom_root)
@ -38,7 +40,7 @@ class TestOpenDALFsDefaultRoot:
# Cleanup
storage.delete("test_explicit_root.txt")
def test_fs_with_env_var_root(self, tmp_path, monkeypatch):
def test_fs_with_env_var_root(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs."""
env_root = str(tmp_path / "env_storage")
monkeypatch.setenv("OPENDAL_FS_ROOT", env_root)

View File

@ -175,7 +175,7 @@ class TestWorkflowPauseIntegration:
"""Comprehensive integration tests for workflow pause functionality."""
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers):
def setup_test_data(self, db_session_with_containers: Session):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account

View File

@ -1,12 +1,14 @@
from textwrap import dedent
from flask import Flask
from .test_utils import CodeExecutorTestMixin
class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
"""Test class for JavaScript code executor functionality."""
def test_javascript_plain(self, flask_app_with_containers):
def test_javascript_plain(self, flask_app_with_containers: Flask):
"""Test basic JavaScript code execution with console.log output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
@ -14,7 +16,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result_message == "Hello World\n"
def test_javascript_json(self, flask_app_with_containers):
def test_javascript_json(self, flask_app_with_containers: Flask):
"""Test JavaScript code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
@ -25,7 +27,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
assert result == '{"Hello":"World"}\n'
def test_javascript_with_code_template(self, flask_app_with_containers):
def test_javascript_with_code_template(self, flask_app_with_containers: Flask):
"""Test JavaScript workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
JavascriptCodeProvider, _ = self.javascript_imports
@ -37,7 +39,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
)
assert result == {"result": "HelloWorld"}
def test_javascript_get_runner_script(self, flask_app_with_containers):
def test_javascript_get_runner_script(self, flask_app_with_containers: Flask):
"""Test JavaScript template transformer runner script generation"""
_, NodeJsTemplateTransformer = self.javascript_imports

View File

@ -1,12 +1,14 @@
import base64
from flask import Flask
from .test_utils import CodeExecutorTestMixin
class TestJinja2CodeExecutor(CodeExecutorTestMixin):
"""Test class for Jinja2 code executor functionality."""
def test_jinja2(self, flask_app_with_containers):
def test_jinja2(self, flask_app_with_containers: Flask):
"""Test basic Jinja2 template execution with variable substitution"""
CodeExecutor, CodeLanguage = self.code_executor_imports
_, Jinja2TemplateTransformer = self.jinja2_imports
@ -25,7 +27,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
)
assert result == "<<RESULT>>Hello World<<RESULT>>\n"
def test_jinja2_with_code_template(self, flask_app_with_containers):
def test_jinja2_with_code_template(self, flask_app_with_containers: Flask):
"""Test Jinja2 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
@ -34,7 +36,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
)
assert result == {"result": "Hello World"}
def test_jinja2_get_runner_script(self, flask_app_with_containers):
def test_jinja2_get_runner_script(self, flask_app_with_containers: Flask):
"""Test Jinja2 template transformer runner script generation"""
_, Jinja2TemplateTransformer = self.jinja2_imports
@ -43,7 +45,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin):
assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
def test_jinja2_template_with_special_characters(self, flask_app_with_containers: Flask):
"""
Test that templates with special characters (quotes, newlines) render correctly.
This is a regression test for issue #26818 where textarea pre-fill values

View File

@ -1,12 +1,14 @@
from textwrap import dedent
from flask import Flask
from .test_utils import CodeExecutorTestMixin
class TestPython3CodeExecutor(CodeExecutorTestMixin):
"""Test class for Python3 code executor functionality."""
def test_python3_plain(self, flask_app_with_containers):
def test_python3_plain(self, flask_app_with_containers: Flask):
"""Test basic Python3 code execution with print output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
@ -14,7 +16,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin):
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == "Hello World\n"
def test_python3_json(self, flask_app_with_containers):
def test_python3_json(self, flask_app_with_containers: Flask):
"""Test Python3 code execution with JSON output"""
CodeExecutor, CodeLanguage = self.code_executor_imports
@ -25,7 +27,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin):
result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
assert result == '{"Hello": "World"}\n'
def test_python3_with_code_template(self, flask_app_with_containers):
def test_python3_with_code_template(self, flask_app_with_containers: Flask):
"""Test Python3 workflow code template execution with inputs"""
CodeExecutor, CodeLanguage = self.code_executor_imports
Python3CodeProvider, _ = self.python3_imports
@ -37,7 +39,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin):
)
assert result == {"result": "HelloWorld"}
def test_python3_get_runner_script(self, flask_app_with_containers):
def test_python3_get_runner_script(self, flask_app_with_containers: Flask):
"""Test Python3 template transformer runner script generation"""
_, Python3TemplateTransformer = self.python3_imports

View File

@ -67,7 +67,7 @@ class TestActivateCheckApi:
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
def test_check_invalid_invitation_token(self, mock_get_invitation, app: Flask):
"""
Test checking invalid invitation token.
@ -227,7 +227,7 @@ class TestActivateApi:
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
def test_activation_with_invalid_token(self, mock_get_invitation, app: Flask):
"""
Test account activation with invalid token.

View File

@ -140,7 +140,7 @@ class TestEmailCodeLoginSendEmailApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app):
def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app: Flask):
"""
Test email code sending blocked by IP rate limit.
@ -160,7 +160,7 @@ class TestEmailCodeLoginSendEmailApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit")
@patch("controllers.console.auth.login.AccountService.get_user_through_email")
def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app):
def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app: Flask):
"""
Test email code sending to frozen account.
@ -353,7 +353,7 @@ class TestEmailCodeLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app):
def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app: Flask):
"""
Test email code login with invalid token.
@ -375,7 +375,7 @@ class TestEmailCodeLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app):
def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app: Flask):
"""
Test email code login with mismatched email.
@ -397,7 +397,7 @@ class TestEmailCodeLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app):
def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app: Flask):
"""
Test email code login with incorrect code.

View File

@ -9,7 +9,7 @@ This module tests the core authentication endpoints including:
"""
import base64
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from flask import Flask
@ -52,12 +52,12 @@ class TestLoginApi:
return app
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return Api(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""Create test client."""
api.add_resource(LoginApi, "/login")
return app.test_client()
@ -97,7 +97,7 @@ class TestLoginApi:
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -141,14 +141,14 @@ class TestLoginApi:
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_successful_login_with_valid_invitation(
self,
mock_reset_rate_limit,
mock_reset_rate_limit: Mock,
mock_login,
mock_get_tenants,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -188,7 +188,7 @@ class TestLoginApi:
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask):
"""
Test login rejection when rate limit is exceeded.
@ -216,7 +216,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True)
@patch("controllers.console.auth.login.BillingService.is_email_in_freeze")
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app):
def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask):
"""
Test login rejection for frozen accounts.
@ -253,7 +253,7 @@ class TestLoginApi:
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
):
"""
Test login failure with invalid credentials.
@ -290,7 +290,7 @@ class TestLoginApi:
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask
):
"""
Test login rejection for banned accounts.
@ -328,14 +328,14 @@ class TestLoginApi:
@patch("controllers.console.auth.login.FeatureService.get_system_features")
def test_login_fails_when_no_workspace_and_limit_exceeded(
self,
mock_get_features,
mock_get_tenants,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
mock_account,
mock_get_features: MagicMock,
mock_get_tenants: MagicMock,
mock_authenticate: MagicMock,
mock_get_invitation: MagicMock,
mock_is_rate_limit: MagicMock,
mock_db: MagicMock,
app: Flask,
mock_account: MagicMock,
):
"""
Test login failure when user has no workspace and workspace limit exceeded.
@ -367,7 +367,7 @@ class TestLoginApi:
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask):
"""
Test login failure when invitation email doesn't match login email.
@ -491,7 +491,7 @@ class TestLogoutApi:
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
def test_successful_logout(
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account
):
"""
Test successful logout flow.
@ -518,7 +518,7 @@ class TestLogoutApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.flask_login")
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app):
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask):
"""
Test logout for anonymous (not logged in) user.

View File

@ -28,12 +28,12 @@ class TestRefreshTokenApi:
return app
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return Api(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""Create test client."""
api.add_resource(RefreshTokenApi, "/refresh-token")
return app.test_client()

View File

@ -49,7 +49,7 @@ class TestPartnerTenants:
mock_csrf.return_value = None
yield {"db": mock_db, "csrf": mock_csrf}
def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_success(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test successful partner tenants bindings sync."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
@ -79,7 +79,7 @@ class TestPartnerTenants:
mock_account.id, "partner-key-123", click_id
)
def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_invalid_partner_key_base64(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test that invalid base64 partner_key raises BadRequest."""
# Arrange
invalid_partner_key = "invalid-base64-!@#$"
@ -104,7 +104,7 @@ class TestPartnerTenants:
resource.put(invalid_partner_key)
assert "Invalid partner_key" in str(exc_info.value)
def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_missing_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test that missing click_id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
@ -128,7 +128,9 @@ class TestPartnerTenants:
with pytest.raises(BadRequest):
resource.put(partner_key_encoded)
def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_billing_service_json_decode_error(
self, app: Flask, mock_account, mock_billing_service, mock_decorators
):
"""Test handling of billing service JSON decode error.
When billing service returns non-200 status code with invalid JSON response,
@ -174,7 +176,7 @@ class TestPartnerTenants:
assert isinstance(exc_info.value, json.JSONDecodeError)
assert "Expecting value" in str(exc_info.value)
def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_empty_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test that empty click_id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")
@ -199,7 +201,7 @@ class TestPartnerTenants:
resource.put(partner_key_encoded)
assert "Invalid partner information" in str(exc_info.value)
def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_empty_partner_key_after_decode(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test that empty partner_key after decode raises BadRequest."""
# Arrange
# Base64 encode an empty string
@ -225,7 +227,7 @@ class TestPartnerTenants:
resource.put(empty_partner_key_encoded)
assert "Invalid partner information" in str(exc_info.value)
def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators):
def test_put_empty_user_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators):
"""Test that empty user id raises BadRequest."""
# Arrange
partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8")

View File

@ -8,10 +8,8 @@ from werkzeug.exceptions import Forbidden
import controllers.console.tag.tags as module
from controllers.console import console_ns
from controllers.console.tag.tags import (
DeprecatedTagBindingCreateApi,
DeprecatedTagBindingRemoveApi,
TagBindingCollectionApi,
TagBindingItemApi,
TagBindingRemoveApi,
TagListApi,
TagUpdateDeleteApi,
)
@ -249,39 +247,13 @@ class TestTagBindingCollectionApi:
method(api)
class TestDeprecatedTagBindingCreateApi:
def test_create_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingCreateApi()
class TestTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = TagBindingRemoveApi()
method = unwrap(api.post)
payload = {
"tag_ids": ["tag-1"],
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
save_mock.assert_called_once()
assert status == 200
assert result["result"] == "success"
class TestTagBindingItemApi:
def test_delete_success(self, app, admin_user, payload_patch):
api = TagBindingItemApi()
method = unwrap(api.delete)
payload = {
"tag_ids": ["tag-1", "tag-2"],
"target_id": "target-1",
"type": "knowledge",
}
@ -295,57 +267,16 @@ class TestTagBindingItemApi:
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api, "tag-1")
result, status = method(api)
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
assert delete_payload.tag_id == "tag-1"
assert delete_payload.target_id == "target-1"
assert delete_payload.type == TagType.KNOWLEDGE
assert status == 200
assert result["result"] == "success"
def test_delete_forbidden(self, app, readonly_user):
api = TagBindingItemApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
class TestDeprecatedTagBindingRemoveApi:
def test_remove_success(self, app, admin_user, payload_patch):
api = DeprecatedTagBindingRemoveApi()
method = unwrap(api.post)
payload = {
"tag_id": "tag-1",
"target_id": "target-1",
"type": "knowledge",
}
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api)
delete_mock.assert_called_once()
assert delete_payload.tag_ids == ["tag-1", "tag-2"]
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
api = DeprecatedTagBindingRemoveApi()
api = TagBindingRemoveApi()
method = unwrap(api.post)
with app.test_request_context("/", json={}):
@ -371,32 +302,30 @@ class TestTagResponseModel:
class TestTagBindingRouteMetadata:
def test_legacy_write_routes_are_marked_deprecated(self):
assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True
def test_write_routes_are_not_deprecated(self):
assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True
assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True
assert TagBindingRemoveApi.post.__apidoc__.get("deprecated") is not True
def test_write_routes_have_stable_operation_ids(self):
assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding"
assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding"
assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated"
assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated"
assert TagBindingRemoveApi.post.__apidoc__["id"] == "remove_tag_bindings"
def test_canonical_and_legacy_write_routes_are_registered(self):
def test_write_routes_are_registered(self):
route_map = {
resource.__name__: urls
for resource, urls, _route_doc, _kwargs in console_ns.resources
if resource.__name__
in {
"TagBindingCollectionApi",
"TagBindingItemApi",
"DeprecatedTagBindingCreateApi",
"DeprecatedTagBindingRemoveApi",
"TagBindingRemoveApi",
}
}
assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",)
assert route_map["TagBindingItemApi"] == ("/tag-bindings/<uuid:id>",)
assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",)
assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",)
assert route_map["TagBindingRemoveApi"] == ("/tag-bindings/remove",)
def test_legacy_write_routes_are_not_registered(self):
urls = {url for _resource, resource_urls, _route_doc, _kwargs in console_ns.resources for url in resource_urls}
assert "/tag-bindings/create" not in urls
assert "/tag-bindings/<uuid:id>" not in urls

View File

@ -17,7 +17,7 @@ def app():
return app
def test_parse_openapi_to_tool_bundle_operation_id(app):
def test_parse_openapi_to_tool_bundle_operation_id(app: Flask):
openapi = {
"openapi": "3.0.0",
"info": {"title": "Simple API", "version": "1.0.0"},
@ -63,7 +63,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app):
assert tool_bundles[2].operation_id == "createResource"
def test_parse_openapi_to_tool_bundle_properties_all_of(app):
def test_parse_openapi_to_tool_bundle_properties_all_of(app: Flask):
openapi = {
"openapi": "3.0.0",
"info": {"title": "Simple API", "version": "1.0.0"},
@ -118,7 +118,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app):
# assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"}
def test_parse_openapi_to_tool_bundle_default_value_type_casting(app):
def test_parse_openapi_to_tool_bundle_default_value_type_casting(app: Flask):
"""
Test that default values are properly cast to match parameter types.
This addresses the issue where array default values like [] cause validation errors

View File

@ -146,7 +146,7 @@ class ControllerApiTestDataFactory:
return app
@staticmethod
def create_api_instance(app):
def create_api_instance(app: Flask):
"""
Create a Flask-RESTX API instance.
@ -160,7 +160,12 @@ class ControllerApiTestDataFactory:
return api
@staticmethod
def create_test_client(app, api, resource_class, route):
def create_test_client(
app: Flask,
api: Api,
resource_class: type,
route: str,
):
"""
Create a Flask test client with a resource registered.
@ -302,7 +307,7 @@ class TestDatasetListApi:
return ControllerApiTestDataFactory.create_flask_app()
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""
Create Flask-RESTX API instance.
@ -311,7 +316,7 @@ class TestDatasetListApi:
return ControllerApiTestDataFactory.create_api_instance(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""
Create test client with DatasetListApi registered.
@ -472,12 +477,12 @@ class TestDatasetApiGet:
return ControllerApiTestDataFactory.create_flask_app()
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return ControllerApiTestDataFactory.create_api_instance(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""Create test client with DatasetApi registered."""
return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/<uuid:dataset_id>")
@ -588,12 +593,12 @@ class TestDatasetApiCreate:
return ControllerApiTestDataFactory.create_flask_app()
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return ControllerApiTestDataFactory.create_api_instance(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""Create test client with DatasetApi registered."""
return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets")
@ -681,12 +686,12 @@ class TestHitTestingApi:
return ControllerApiTestDataFactory.create_flask_app()
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return ControllerApiTestDataFactory.create_api_instance(app)
@pytest.fixture
def client(self, app, api):
def client(self, app: Flask, api: Api):
"""Create test client with HitTestingApi registered."""
return ControllerApiTestDataFactory.create_test_client(
app, api, HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing"
@ -799,12 +804,12 @@ class TestExternalDatasetApi:
return ControllerApiTestDataFactory.create_flask_app()
@pytest.fixture
def api(self, app):
def api(self, app: Flask):
"""Create Flask-RESTX API instance."""
return ControllerApiTestDataFactory.create_api_instance(app)
@pytest.fixture
def client_list(self, app, api):
def client_list(self, app: Flask, api: Api):
"""Create test client for external knowledge API list endpoint."""
return ControllerApiTestDataFactory.create_test_client(
app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api"

36
api/uv.lock generated
View File

@ -1629,7 +1629,7 @@ dev = [
{ name = "lxml-stubs", specifier = ">=0.5.1" },
{ name = "mypy", specifier = ">=1.20.2" },
{ name = "pandas-stubs", specifier = ">=3.0.0" },
{ name = "pyrefly", specifier = ">=0.62.0" },
{ name = "pyrefly", specifier = ">=0.64.0" },
{ name = "pytest", specifier = ">=9.0.3" },
{ name = "pytest-benchmark", specifier = ">=5.2.3" },
{ name = "pytest-cov", specifier = ">=7.1.0" },
@ -2657,14 +2657,14 @@ wheels = [
[[package]]
name = "gitpython"
version = "3.1.47"
version = "3.1.49"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "gitdb" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" }
sdist = { url = "https://files.pythonhosted.org/packages/e1/63/210aaa302d6a0a78daa67c5c15bbac2cad361722841278b0209b6da20855/gitpython-3.1.49.tar.gz", hash = "sha256:42f9399c9eb33fc581014bedd76049dfbaf6375aa2a5754575966387280315e1", size = 219367, upload-time = "2026-04-29T00:31:20.478Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" },
{ url = "https://files.pythonhosted.org/packages/fd/6f/b842bfa6f21d6f87c57f9abf7194225e55279d96d869775e19e9f7236fc5/gitpython-3.1.49-py3-none-any.whl", hash = "sha256:024b0422d7f84d15cd794844e029ffebd4c5d42a7eb9b936b458697ef550a02c", size = 212190, upload-time = "2026-04-29T00:31:18.412Z" },
]
[[package]]
@ -3740,14 +3740,14 @@ wheels = [
[[package]]
name = "mako"
version = "1.3.11"
version = "1.3.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markupsafe" },
]
sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" }
sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" },
{ url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" },
]
[[package]]
@ -5359,19 +5359,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.62.0"
version = "0.64.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/bb/ad/8874ed25781e7dd561c6d75fb4a7becf10a18d75b074f25b845cc334f781/pyrefly-0.62.0.tar.gz", hash = "sha256:da1fbe1075dc1e6c8e3134e9370b0a0e7a296061d782cca5bf83dbb8e4c10d7c", size = 5537672, upload-time = "2026-04-20T17:12:15.718Z" }
sdist = { url = "https://files.pythonhosted.org/packages/85/99/923622d7b52ef84e83f357b19bd08dff063ccc5f4472b003105e1f308d93/pyrefly-0.64.0.tar.gz", hash = "sha256:fbfcdb0031adadc340b6c64cb41c6094c95349ee952fe3d4c143866add829172", size = 5678516, upload-time = "2026-05-06T17:28:44.056Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1b/ea/09bd9da7d5df294db800312fb415be2fefbaa5594178e9e49f44fa071aea/pyrefly-0.62.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9d78ec4f126dee1fa76215b193b964490ce10e62a32d2787a72c51623658b803", size = 13020414, upload-time = "2026-04-20T17:11:43.617Z" },
{ url = "https://files.pythonhosted.org/packages/4b/f0/f84afac4f220c4c8c801b779ee2ff28ad3f7731f4283c2e1b6ee9012e8c2/pyrefly-0.62.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2a41a34902d20756264486f9e309f22633d100261bd960feea6e858a098d985d", size = 12515659, upload-time = "2026-04-20T17:11:46.59Z" },
{ url = "https://files.pythonhosted.org/packages/40/0b/620c39cefa9ae1b25ee7a2da9d8d3c278b095649cb8435c5e01ea64f7c17/pyrefly-0.62.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4666c6b65aea662e5f77b64dc91c091b7ea5cede6aa66c0f4cbae26480403583", size = 36228332, upload-time = "2026-04-20T17:11:50.523Z" },
{ url = "https://files.pythonhosted.org/packages/2d/fb/47b8b76438c12761e509a3666cd5a99d4af7f21976ba8385feb475cbfe30/pyrefly-0.62.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1aefab798f47d37c13ded791192fee9b39a6d2b12e31f38ae06a1f80c4b26e22", size = 38995741, upload-time = "2026-04-20T17:11:54.702Z" },
{ url = "https://files.pythonhosted.org/packages/55/d2/03bd17673f61147cd5609cd7d6a1455eeccc17a07a7e141ed9931b0c42c0/pyrefly-0.62.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa986b50d56740da1d7ae7c660a505143cb9d286fa98cc7e5f4a759cc6eaa5d", size = 37205321, upload-time = "2026-04-20T17:11:58.9Z" },
{ url = "https://files.pythonhosted.org/packages/75/14/20ba7b7f2d182f9b7c1e24a3041dac9b5730ae28cfe1614a2c98706650f2/pyrefly-0.62.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32e9b175805c82ffb967e4708f4910bace7e1a12736907380cc9afdbaabb0efb", size = 41786834, upload-time = "2026-04-20T17:12:03.221Z" },
{ url = "https://files.pythonhosted.org/packages/fa/c8/5a7ba88c4fa1b5090d877f70fa1b742b921b9e7d8d3f4b6b9b1ba1820850/pyrefly-0.62.0-py3-none-win32.whl", hash = "sha256:1cd98edc20cab5bac8016c9220ee66080e39bd22e7f0e9bb3e2c4e2be1555eed", size = 12010170, upload-time = "2026-04-20T17:12:06.791Z" },
{ url = "https://files.pythonhosted.org/packages/2e/78/d8f810de010ff2ed594c630c724fd817ef430963249e9eb396ce8f785e9d/pyrefly-0.62.0-py3-none-win_amd64.whl", hash = "sha256:6994f8ee7d6720325ee52207fbdaca98a799a1efe462bb5ba90c47160f7f3e6e", size = 12861816, upload-time = "2026-04-20T17:12:09.689Z" },
{ url = "https://files.pythonhosted.org/packages/c7/a9/ac824ef6a3f50b7c0ec5974471f8f2cb205cd1edd53a5abbcf7ba37feb5d/pyrefly-0.62.0-py3-none-win_arm64.whl", hash = "sha256:362a5d47a5ac5aaa5258091e878a1759ff8b687d8cf462af1c516144f7b0108a", size = 12352977, upload-time = "2026-04-20T17:12:12.736Z" },
{ url = "https://files.pythonhosted.org/packages/b8/1c/b001b7e84a811dbb3c85e31bd4bfc3edfa3c94438140cd1d6e8c06b7c1df/pyrefly-0.64.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:683b317d8d0e815fb2ad75b7e0fa6c15eed5be4bcbc407dc13312984da3a9c47", size = 13287462, upload-time = "2026-05-06T17:28:19.169Z" },
{ url = "https://files.pythonhosted.org/packages/89/02/1e6fcd311bd7c24aaccc0afb998d584e1fa6c370e1428b4b091103760efe/pyrefly-0.64.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:96913cc4f066a7bd008b9dba8e3951234e92bb8a3a2cb1aea0e274fd2a444c55", size = 12777104, upload-time = "2026-05-06T17:28:22.047Z" },
{ url = "https://files.pythonhosted.org/packages/d6/2b/3f347b8d97c9065d6ace14a22591c8d91e64610e74e0d4f214b3025ebcf7/pyrefly-0.64.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2ae557e1b6a6a5bda844806cae10b212cf84ea786ece10d55083a0321ee1705", size = 37064924, upload-time = "2026-05-06T17:28:24.743Z" },
{ url = "https://files.pythonhosted.org/packages/73/dd/0b40175e930a96139a8e9f62a8e1db7f9a5e9df8e6cef08bf280affcb05e/pyrefly-0.64.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d062ac1744346efacd7df23c6bbff662ad29ed495923cb59ede656a306355655", size = 39719832, upload-time = "2026-05-06T17:28:28.042Z" },
{ url = "https://files.pythonhosted.org/packages/9a/4b/0afb4ad02eb67ddb299ff3f7108ceb307e520578b00e900d07f2371423ca/pyrefly-0.64.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6850b305d45121911fbe25ad56497d2e887b387ea50644ba15a8ad2a8cf855f4", size = 37861666, upload-time = "2026-05-06T17:28:31.234Z" },
{ url = "https://files.pythonhosted.org/packages/e5/1b/f5390f8678433708288afab13f043ddd021a55dba3f665360d2c9396ee04/pyrefly-0.64.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a259925620a84fe87cd30a82643ec524eeef631f0c4ec5af81a21e006c2f5b1", size = 42634235, upload-time = "2026-05-06T17:28:34.405Z" },
{ url = "https://files.pythonhosted.org/packages/47/f7/4b66934e375dde3e4d75373b1a94eb7e7c0c0c788e94267641a223930180/pyrefly-0.64.0-py3-none-win32.whl", hash = "sha256:20317f6dd97e22bc508b8dbc537e59b0ab58e384113ee61920c87ed1a6a12f62", size = 12213388, upload-time = "2026-05-06T17:28:37.146Z" },
{ url = "https://files.pythonhosted.org/packages/0a/15/653523d99795041a1be6dadf7a73225317cb2aae4b21e6df57edbce807f0/pyrefly-0.64.0-py3-none-win_amd64.whl", hash = "sha256:e88fc6a83add9b7c2224be0f74df1b0db10b3af856ae30e4e0a90ba3644c712f", size = 13136719, upload-time = "2026-05-06T17:28:39.767Z" },
{ url = "https://files.pythonhosted.org/packages/50/bb/9ea1c26b511b38a3e1eefc1bd3de7d3f65b2bbfdb59295f3244f61564a81/pyrefly-0.64.0-py3-none-win_arm64.whl", hash = "sha256:73744bd95e836abda0d08e9cdcf008142090ae0124c8f8ff477c944b60c0343c", size = 12526050, upload-time = "2026-05-06T17:28:42.077Z" },
]
[[package]]

View File

@ -155,9 +155,6 @@
}
},
"web/app/account/(commonLayout)/account-page/email-change-modal.tsx": {
"erasable-syntax-only/enums": {
"count": 1
},
"ts/no-explicit-any": {
"count": 5
}
@ -1824,26 +1821,6 @@
"count": 1
}
},
"web/app/components/base/tag-management/__tests__/panel.spec.tsx": {
"ts/no-explicit-any": {
"count": 2
}
},
"web/app/components/base/tag-management/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/base/tag-management/tag-item-editor.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/base/tag-management/tag-remove-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/base/text-generation/hooks.ts": {
"ts/no-explicit-any": {
"count": 1
@ -1921,11 +1898,6 @@
"count": 4
}
},
"web/app/components/billing/plan/index.tsx": {
"ts/no-explicit-any": {
"count": 2
}
},
"web/app/components/billing/pricing/assets/index.tsx": {
"no-barrel-files/no-barrel-files": {
"count": 12
@ -2359,11 +2331,6 @@
"count": 1
}
},
"web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts": {
"react/set-state-in-effect": {
"count": 1
}
},
"web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx": {
"ts/no-explicit-any": {
"count": 2
@ -2469,17 +2436,6 @@
"count": 2
}
},
"web/app/components/explore/create-app-modal/index.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
},
"unicorn/prefer-number-properties": {
"count": 1
}
},
"web/app/components/explore/item-operation/index.tsx": {
"react/set-state-in-effect": {
"count": 1
@ -4238,11 +4194,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-value-method.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/index.tsx": {
"no-restricted-imports": {
"count": 1
@ -5099,11 +5050,6 @@
"count": 5
}
},
"web/app/education-apply/verify-state-modal.tsx": {
"react/set-state-in-effect": {
"count": 1
}
},
"web/app/forgot-password/ForgotPasswordForm.spec.tsx": {
"ts/no-explicit-any": {
"count": 5
@ -5378,11 +5324,6 @@
"count": 2
}
},
"web/service/knowledge/use-dataset.ts": {
"@tanstack/query/exhaustive-deps": {
"count": 1
}
},
"web/service/share.ts": {
"erasable-syntax-only/enums": {
"count": 1

View File

@ -4156,8 +4156,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesResponse
export type DeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdData = {
body?: never
path: {
app_id: string
variable_id: string
app_id: string
}
query?: never
url: '/apps/{app_id}/workflows/draft/variables/{variable_id}'
@ -4210,8 +4210,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesByVariableIdResponse
export type PatchAppsByAppIdWorkflowsDraftVariablesByVariableIdData = {
body: WorkflowDraftVariableUpdatePayload
path: {
app_id: string
variable_id: string
app_id: string
}
query?: never
url: '/apps/{app_id}/workflows/draft/variables/{variable_id}'

View File

@ -2980,8 +2980,8 @@ export const zGetAppsByAppIdWorkflowsDraftVariablesQuery = z.object({
export const zGetAppsByAppIdWorkflowsDraftVariablesResponse = zWorkflowDraftVariableListWithoutValue
export const zDeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({
app_id: z.string(),
variable_id: z.string(),
app_id: z.string(),
})
/**
@ -3006,8 +3006,8 @@ export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdBody
= zWorkflowDraftVariableUpdatePayload
export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({
app_id: z.string(),
variable_id: z.string(),
app_id: z.string(),
})
/**

View File

@ -255,6 +255,7 @@ export type ProcessRule = {
}
export type RetrievalModel = {
metadata_filtering_conditions?: MetadataFilteringCondition
reranking_enable: boolean
reranking_mode?: string | null
reranking_model?: RerankingModel
@ -312,6 +313,11 @@ export type Rule = {
subchunk_segmentation?: Segmentation
}
export type MetadataFilteringCondition = {
conditions?: Array<Condition> | null
logical_operator?: 'and' | 'or' | null
}
export type RerankingModel = {
reranking_model_name?: string | null
reranking_provider_name?: string | null
@ -405,6 +411,30 @@ export type Segmentation = {
separator?: string
}
export type Condition = {
comparison_operator:
| 'contains'
| 'not contains'
| 'start with'
| 'end with'
| 'is'
| 'is not'
| 'empty'
| 'not empty'
| 'in'
| 'not in'
| '='
| '≠'
| '>'
| '<'
| '≥'
| '≤'
| 'before'
| 'after'
name: string
value?: unknown
}
export type WeightKeywordSetting = {
keyword_weight: number
}
@ -1174,8 +1204,8 @@ export type PatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse
export type DeleteDatasetsByDatasetIdDocumentsByDocumentIdData = {
body?: never
path: {
dataset_id: string
document_id: string
dataset_id: string
}
query?: never
url: '/datasets/{dataset_id}/documents/{document_id}'

View File

@ -392,6 +392,46 @@ export const zProcessRule = z.object({
rules: zRule.optional(),
})
/**
* Condition
*
* Condition detail
*/
export const zCondition = z.object({
comparison_operator: z.enum([
'contains',
'not contains',
'start with',
'end with',
'is',
'is not',
'empty',
'not empty',
'in',
'not in',
'=',
'≠',
'>',
'<',
'≥',
'≤',
'before',
'after',
]),
name: z.string(),
value: z.unknown().optional(),
})
/**
* MetadataFilteringCondition
*
* Metadata Filtering Condition.
*/
export const zMetadataFilteringCondition = z.object({
conditions: z.array(zCondition).nullish(),
logical_operator: z.enum(['and', 'or']).nullish().default('and'),
})
/**
* WeightKeywordSetting
*/
@ -421,6 +461,7 @@ export const zWeightModel = z.object({
* RetrievalModel
*/
export const zRetrievalModel = z.object({
metadata_filtering_conditions: zMetadataFilteringCondition.optional(),
reranking_enable: z.boolean(),
reranking_mode: z.string().nullish(),
reranking_model: zRerankingModel.optional(),
@ -925,8 +966,8 @@ export const zPatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse = z.r
)
export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdPath = z.object({
dataset_id: z.string(),
document_id: z.string(),
dataset_id: z.string(),
})
/**

Some files were not shown because too many files have changed in this diff Show More