Merge remote-tracking branch 'origin/main'

# Conflicts:
#	api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py
This commit is contained in:
FFXN 2026-05-08 11:32:47 +08:00
commit 6c5f6699d2
1077 changed files with 84896 additions and 27793 deletions

View File

@ -1,6 +1,6 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---
# Frontend Query & Mutation
@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep invalidation and mutation flow knowledge in the service layer.
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
- Keep abstractions minimal to preserve TypeScript inference.
## Workflow
1. Identify the change surface.
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
- Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
- Extract a small shared query helper only when multiple call sites share the same extra options.
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
- Create or keep feature hooks only for real orchestration or shared domain behavior.
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
3. Preserve Dify conventions.
- Keep contract inputs in `{ params, query?, body? }` shape.
- Bind invalidation in the service-layer mutation definition.
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
## Files Commonly Touched
@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- `web/service/use-*.ts`
- legacy `web/service/use-*.ts` files when migrating wrappers away
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
## References

View File

@ -1,4 +1,4 @@
interface:
display_name: "Frontend Query & Mutation"
short_description: "Dify TanStack Query and oRPC patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
short_description: "Dify TanStack Query, oRPC, and default option patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."

View File

@ -7,6 +7,7 @@
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Thin hook decision rule
- Anti-patterns
- Contract rules
- Type export
@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration.
3. Create or keep feature hooks only for orchestration.
- Combine multiple queries or mutations.
- Share domain-level derived state or invalidation helpers.
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
4. Treat `web/service/use-{domain}.ts` as legacy.
- Do not create new thin service wrappers for oRPC contracts.
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
```typescript
const invoicesBaseQueryOptions = () =>
@ -74,11 +79,37 @@ const invoiceQuery = useQuery({
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
```typescript
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
```
## Thin Hook Decision Rule
Remove thin hooks when they only rename a single oRPC query or mutation helper.
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
Use:
```typescript
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
```
Keep:
```typescript
const applyTagBindingsMutation = useApplyTagBindingsMutation()
```
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
## Contract Rules

View File

@ -3,6 +3,7 @@
## Table of Contents
- Conditional queries
- oRPC default options
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) {
}
```
## oRPC Default Options
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
Place defaults at the query utility creation point in `web/service/client.ts`:
```typescript
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
path: ['console'],
experimental_defaults: {
tags: {
create: {
mutationOptions: {
onSuccess: (tag, _variables, _result, context) => {
context.client.setQueryData(
consoleQuery.tags.list.queryKey({
input: {
query: {
type: tag.type,
},
},
}),
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
)
},
},
},
},
},
})
```
Rules:
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
## Cache Invalidation
Bind invalidation in the service-layer mutation definition.
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
Use:
@ -49,7 +91,7 @@ Use:
Do not use deprecated `useInvalid` from `use-base.ts`.
```typescript
// Service layer owns cache invalidation.
// Feature orchestration owns cache invalidation only when defaults are not enough.
export const useUpdateAccessMode = () => {
const queryClient = useQueryClient()
@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules:
| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |

3
.github/CODEOWNERS vendored
View File

@ -6,6 +6,9 @@
* @crazywoola @laipz8200 @Yeuoly
# ESLint suppression file is maintained by autofix.ci pruning.
/eslint-suppressions.json
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola

View File

@ -4,7 +4,7 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
with:
node-version-file: .nvmrc
cache: true

1
.github/labeler.yml vendored
View File

@ -6,5 +6,4 @@ web:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

View File

@ -43,7 +43,6 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
@ -114,7 +113,7 @@ jobs:
find . -name "*.py.bak" -type f -delete
- name: Setup web environment
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
if: github.event_name != 'merge_group'
uses: ./.github/actions/setup-web
- name: ESLint autofix

View File

@ -74,7 +74,7 @@ jobs:
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Depot CLI
uses: depot/setup-action@v1
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Extract metadata for Docker
id: meta
@ -84,7 +84,7 @@ jobs:
- name: Build Docker image
id: build
uses: depot/build-push-action@v1
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }}
@ -124,10 +124,10 @@ jobs:
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Validate Docker image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.build_context }}

View File

@ -44,10 +44,10 @@ jobs:
file: "web/Dockerfile"
steps:
- name: Set up Depot CLI
uses: depot/setup-action@v1
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Build Docker Image
uses: depot/build-push-action@v1
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false
@ -71,10 +71,10 @@ jobs:
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.context }}

View File

@ -69,7 +69,6 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@ -83,7 +82,6 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'

View File

@ -83,7 +83,6 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**

View File

@ -9,7 +9,6 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

3
.gitignore vendored
View File

@ -219,6 +219,9 @@ node_modules
# plugin migrate
plugins.jsonl
# generated API OpenAPI specs
packages/contracts/openapi/
# mise
mise.toml

1
.npmrc
View File

@ -1 +0,0 @@
save-exact=true

View File

@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
```bash
cd dify
cd docker
cp .env.example .env
docker compose up -d
./dify-compose up -d
```
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
#### Seeking help
@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana

View File

@ -113,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
# Validates name encoding for non-Latin characters.
name = name.strip().encode("utf-8").decode("utf-8") if name else None
# generate random password
new_password = secrets.token_urlsafe(16)
# Generate a random password that satisfies the password policy.
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
for _ in range(100):
new_password = secrets.token_urlsafe(16)
try:
valid_password(new_password)
break
except Exception:
continue
else:
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
return
# register account
account = RegisterService.register(

View File

@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response):
# Try to extract filename from URL
parsed_url = urllib.parse.urlparse(url)
url_path = parsed_url.path
filename = os.path.basename(url_path)
# Decode percent-encoded characters in the path segment
filename = urllib.parse.unquote(os.path.basename(url_path))
# If filename couldn't be extracted, use Content-Disposition header
if not filename:

View File

@ -1,4 +1,5 @@
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Literal
@ -8,6 +9,7 @@ from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
@ -66,22 +69,19 @@ class AppListQuery(BaseModel):
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not isinstance(value, list):
raise ValueError("Unsupported tag_ids type.")
items = [str(item).strip() for item in value if item and str(item).strip()]
if not items:
return None
@ -91,6 +91,26 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
    """Flatten request query args, folding ``tag_ids[<index>]`` keys into one ordered list.

    Bracket-indexed keys such as ``tag_ids[0]`` / ``tag_ids[1]`` are collected,
    ordered by their numeric index, and exposed under a single ``tag_ids`` entry.
    Every other key keeps its first value unchanged.
    """
    result: dict[str, str | list[str]] = {}
    bracketed: list[tuple[int, str]] = []
    for raw_key in query_args:
        bracket_match = _TAG_IDS_BRACKET_PATTERN.fullmatch(raw_key)
        if bracket_match is None:
            first_value = query_args.get(raw_key)
            if first_value is not None:
                result[raw_key] = first_value
        else:
            index = int(bracket_match.group(1))
            for item in query_args.getlist(raw_key):
                bracketed.append((index, item))
    if bracketed:
        bracketed.sort()
        result["tag_ids"] = [item for _, item in bracketed]
    return result
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -455,7 +475,7 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
args_dict = args.model_dump()
# get app list

View File

@ -60,7 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -158,8 +159,13 @@ class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
class WorkflowOnlineUsersQuery(BaseModel):
app_ids: str = Field(..., description="Comma-separated app IDs")
class WorkflowOnlineUsersPayload(BaseModel):
    """Request body for querying online users across a set of workflow apps."""

    # App IDs to look up; validator below strips, drops empties, and de-duplicates.
    app_ids: list[str] = Field(default_factory=list, description="App IDs")

    @field_validator("app_ids")
    @classmethod
    def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
        """Trim whitespace, drop blank entries, and de-duplicate preserving first-seen order."""
        seen: dict[str, None] = {}
        for raw_id in app_ids:
            cleaned = raw_id.strip()
            if cleaned:
                seen.setdefault(cleaned, None)
        return list(seen)
class DraftWorkflowTriggerRunPayload(BaseModel):
@ -186,7 +192,7 @@ reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersQuery)
reg(WorkflowOnlineUsersPayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
@ -1384,19 +1390,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.route("/apps/workflows/online-users")
class WorkflowOnlineUsersApi(Resource):
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
@console_ns.doc("get_workflow_online_users")
@console_ns.doc(description="Get workflow online users")
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
def post(self):
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
app_ids = args.app_ids
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
if not app_ids:
return {"data": []}
@ -1404,13 +1410,24 @@ class WorkflowOnlineUsersApi(Resource):
_, current_tenant_id = current_account_with_tenant()
workflow_service = WorkflowService()
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
users_json_by_app_id: dict[str, Any] = {}
for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
app_id_batch = ordered_accessible_app_ids[
start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
]
pipe = redis_client.pipeline(transaction=False)
for app_id in app_id_batch:
pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
users_json_batch = pipe.execute()
for app_id, users_json in zip(app_id_batch, users_json_batch):
users_json_by_app_id[app_id] = users_json
results = []
for app_id in app_ids:
if app_id not in accessible_app_ids:
continue
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
for app_id in ordered_accessible_app_ids:
users_json = users_json_by_app_id.get(app_id, {})
users = []
for _, user_info_json in users_json.items():

View File

@ -75,14 +75,15 @@ console_ns.schema_model(
def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
match value:
case FileSegment():
return value.value.model_dump()
case ArrayFileSegment():
return [i.model_dump() for i in value.value]
case SegmentGroup():
return [_convert_values_to_json_serializable_object(i) for i in value.value]
case _:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable):

View File

@ -32,12 +32,7 @@ class TagBindingPayload(BaseModel):
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
@ -75,7 +70,6 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -184,13 +178,13 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
@ -211,54 +205,15 @@ class TagBindingCollectionApi(Resource):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
class TagBindingRemoveApi(Resource):
"""Batch resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.doc("remove_tag_bindings")
@console_ns.doc(description="Remove one or more tag bindings from a target.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_binding()
return _remove_tag_bindings()

View File

@ -8,6 +8,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import supported_language
@ -45,6 +46,8 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.enums import CreatorUserRole
from models.model import UploadFile
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -322,9 +325,24 @@ class AccountAvatarApi(Resource):
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
avatar = args.avatar
avatar_url = file_helpers.get_signed_file_url(args.avatar)
if avatar.startswith(("http://", "https://")):
return {"avatar_url": avatar}
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
if upload_file is None:
raise NotFound("Avatar file not found")
if upload_file.tenant_id != current_tenant_id:
raise NotFound("Avatar file not found")
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
raise NotFound("Avatar file not found")
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {"avatar_url": avatar_url}
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
class TagUnbindingPayload(BaseModel):
tag_id: str
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
tag_ids: list[str] = Field(default_factory=list)
tag_id: str | None = None
target_id: str
@model_validator(mode="before")
@classmethod
def normalize_legacy_tag_id(cls, data: object) -> object:
if not isinstance(data, dict):
return data
if not data.get("tag_ids") and data.get("tag_id"):
return {**data, "tag_ids": [data["tag_id"]]}
return data
@model_validator(mode="after")
def validate_tag_ids(self) -> "TagUnbindingPayload":
if not self.tag_ids:
raise ValueError("Tag IDs is required.")
return self
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc("unbind_dataset_tags")
@service_api_ns.doc(description="Unbind tags from a dataset")
@service_api_ns.doc(
responses={
204: "Tag unbound successfully",
204: "Tags unbound successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
}
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204

View File

@ -468,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
"""Update a document from an uploaded file for canonical and deprecated routes."""
dataset_id_str = str(dataset_id)
tenant_id_str = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args: dict[str, object] = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"""Deprecated resource aliases for file document updates."""
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc("update_document_by_file_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for updating an existing document by uploading a file. "
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
)
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
@ -487,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file through the deprecated file-update aliases."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@ -876,6 +886,22 @@ class DocumentApi(DatasetApiResource):
return response
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
401: "Unauthorized - invalid API token",
404: "Document not found",
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file on the canonical document resource."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.doc("delete_document")
@service_api_ns.doc(description="Delete a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

View File

@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
def post(self, app_model: App, end_user: EndUser):
"""Convert audio to text"""
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
def post(self, app_model: App, end_user: EndUser):
"""Convert text to audio"""
try:
payload = TextToAudioPayload.model_validate(web_ns.payload or {})

View File

@ -151,6 +151,12 @@ def deserialize_response(raw_data: bytes) -> Response:
response = Response(response=body, status=status_code)
# Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
# parsed ones so we faithfully reproduce the original response. Use Headers.add
# rather than dict-style assignment so that repeated headers such as Set-Cookie
# (and any other multi-valued header per RFC 9110) are preserved instead of
# being overwritten.
response.headers.clear()
for line in lines[1:]:
if not line:
continue
@ -158,6 +164,6 @@ def deserialize_response(raw_data: bytes) -> Response:
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
response.headers[name] = value.strip()
response.headers.add(name, value.strip())
return response

View File

@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -445,7 +445,7 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
provider_name_to_provider_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
@ -462,7 +462,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
@ -478,7 +478,7 @@ class ProviderManager:
:return:
"""
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
@ -496,7 +496,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_settings_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
@ -514,7 +514,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
@ -544,7 +544,7 @@ class ProviderManager:
return {}
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
@ -578,7 +578,7 @@ class ProviderManager:
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderCredential)
.where(
@ -608,7 +608,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderModelCredential)
.where(

View File

@ -217,10 +217,11 @@ class RetrievalService:
"""Deduplicate documents in O(n) while preserving first-seen order.
Rules:
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
- If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key;
keep the doc with the highest metadata["score"] among duplicates. If a later duplicate
has no score, ignore it.
- If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content),
keeping the first occurrence.
"""
if not documents:
return documents
@ -231,11 +232,10 @@ class RetrievalService:
order: list[tuple] = []
for doc in documents:
is_dify = doc.provider == "dify"
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
doc_id = (doc.metadata or {}).get("doc_id")
if is_dify and doc_id:
key = ("dify", doc_id)
if doc_id:
key = (doc.provider or "dify", doc_id)
if key not in chosen:
chosen[key] = doc
order.append(key)

View File

@ -144,8 +144,20 @@ class Vector:
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
return get_vector_factory_class(vector_type)
@staticmethod
def _filter_empty_text_documents(documents: list[Document]) -> list[Document]:
filtered_documents = [document for document in documents if document.page_content.strip()]
skipped_count = len(documents) - len(filtered_documents)
if skipped_count:
logger.warning("skip %d empty documents before vector embedding", skipped_count)
return filtered_documents
def create(self, texts: list | None = None, **kwargs):
if texts:
texts = self._filter_empty_text_documents(texts)
if not texts:
return
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
@ -203,8 +215,14 @@ class Vector:
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
documents = self._filter_empty_text_documents(documents)
if not documents:
return
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
if not documents:
return
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)

View File

@ -1078,6 +1078,13 @@ class ToolManager:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
if variable_pool:
config = tool_configurations.get(parameter.name, {})
selector_value = cls._extract_runtime_selector_value(parameter, config)
if selector_value is not None:
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
runtime_parameters[parameter.name] = selector_value
continue
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
@ -1105,5 +1112,39 @@ class ToolManager:
runtime_parameters[parameter.name] = value
return runtime_parameters
@classmethod
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
if parameter.type not in {
ToolParameter.ToolParameterType.MODEL_SELECTOR,
ToolParameter.ToolParameterType.APP_SELECTOR,
}:
return None
if not isinstance(config, dict):
return None
input_value = config.get("value")
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
if cls._is_selector_value(parameter, config):
selector_value = dict(config)
selector_value.pop("type", None)
selector_value.pop("value", None)
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
return None
@classmethod
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
)
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
return isinstance(value.get("app_id"), str)
return False
ToolManager.load_hardcoded_providers_cache()

View File

@ -272,6 +272,14 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
normalized_tool_configurations[name] = value
continue
selector_value = _extract_selector_configuration(value)
if selector_value is not None:
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
@ -310,6 +318,28 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
return None
def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
input_value = value.get("value")
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
return dict(input_value)
if _is_selector_configuration(value):
selector_value = dict(value)
selector_value.pop("type", None)
selector_value.pop("value", None)
return selector_value
return None
def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
) or isinstance(value.get("app_id"), str)
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)

View File

@ -365,7 +365,8 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
adapted_node_config = adapt_node_config_for_graph(node_config)
typed_node_config = NodeConfigDictAdapter.validate_python(adapted_node_config)
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
@ -373,6 +374,11 @@ class DifyNodeFactory(NodeFactory):
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
config_for_node_init: BaseNodeData | dict[str, Any]
if isinstance(resolved_node_data, BaseNodeData):
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
else:
config_for_node_init = resolved_node_data
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -442,7 +448,7 @@ class DifyNodeFactory(NodeFactory):
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
node_id=node_id,
config=resolved_node_data,
config=config_for_node_init,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@ -474,10 +480,7 @@ class DifyNodeFactory(NodeFactory):
include_retriever_attachment_loader: bool,
include_jinja2_template_renderer: bool,
) -> dict[str, object]:
validated_node_data = cast(
LLMCompatibleNodeData,
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
)
validated_node_data = cast(LLMCompatibleNodeData, node_data)
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
node_init_kwargs: dict[str, object] = {
"credentials_provider": self._llm_credentials_provider,

View File

@ -501,11 +501,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@staticmethod
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
tool_configurations = dict(node_data.tool_configurations)
tool_configurations.update(
{name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()}
)
return _WorkflowToolRuntimeSpec(
provider_type=CoreToolProviderType(node_data.provider_type.value),
provider_id=node_data.provider_id,
tool_name=node_data.tool_name,
tool_configurations=dict(node_data.tool_configurations),
tool_configurations=tool_configurations,
credential_id=node_data.credential_id,
)

View File

@ -3,6 +3,7 @@ import logging
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from events.app_event import app_draft_workflow_was_synced
from graphon.nodes import BuiltinNodeTypes
from graphon.nodes.tool.entities import ToolEntity
@ -19,7 +20,8 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL:
try:
tool_entity = ToolEntity.model_validate(node_data["data"])
adapted_node_data = adapt_node_config_for_graph(node_data)
tool_entity = ToolEntity.model_validate(adapted_node_data["data"])
provider_type = ToolProviderType(tool_entity.provider_type.value)
tool_runtime = ToolManager.get_tool_runtime(
provider_type=provider_type,

View File

@ -1,7 +1,9 @@
from flask import Flask
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
def init_app(app: Flask):
with app.app_context():
configure_session_factory(db.engine)

View File

@ -298,7 +298,7 @@ def _build_from_datasource_file(
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),

View File

@ -19,8 +19,13 @@ from werkzeug.http import parse_options_header
from core.helper import ssrf_proxy
def extract_filename(url_path: str, content_disposition: str | None) -> str | None:
"""Extract a safe filename from Content-Disposition or the request URL path."""
def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
"""Extract a safe filename from Content-Disposition or the request URL path.
Handles full URLs, paths with query strings, hash fragments, and percent-encoded segments.
Query strings and hash fragments are stripped from the URL before extracting the basename.
Percent-encoded characters in the path are decoded safely.
"""
filename: str | None = None
if content_disposition:
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
@ -47,8 +52,13 @@ def extract_filename(url_path: str, content_disposition: str | None) -> str | No
filename = urllib.parse.unquote(raw)
if not filename:
candidate = os.path.basename(url_path)
filename = urllib.parse.unquote(candidate) if candidate else None
# Parse the URL to extract just the path, stripping query strings and fragments
# This handles both full URLs and bare paths
parsed = urllib.parse.urlparse(url_or_path)
path = parsed.path
candidate = os.path.basename(path)
# Decode percent-encoded characters, with safe fallback for malformed input
filename = urllib.parse.unquote(candidate, errors="replace") if candidate else None
if filename:
filename = os.path.basename(filename)

View File

@ -1,9 +0,0 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)

View File

@ -9,11 +9,11 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from core.db.session_factory import session_factory
from graphon.model_runtime.entities.model_entities import ModelType
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@ -82,7 +82,8 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
with session_factory.create_session() as session:
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
with session_factory.create_session() as session:
return session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
def credential_name(self):

View File

@ -1568,12 +1568,14 @@ class WorkflowDraftVariable(Base):
),
)
# Relationship to WorkflowDraftVariableFile
# WorkflowDraftVariableFile uses TypeBase while WorkflowDraftVariable uses Base, so the relationship
# must resolve the class object lazily instead of relying on string lookup across registries.
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
lambda: WorkflowDraftVariableFile,
foreign_keys=[file_id],
lazy="raise",
uselist=False,
primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id",
primaryjoin=lambda: orm.foreign(WorkflowDraftVariable.file_id) == WorkflowDraftVariableFile.id,
)
# Cache for deserialized value
@ -1892,7 +1894,7 @@ class WorkflowDraftVariable(Base):
return self.last_edited_at is not None
class WorkflowDraftVariableFile(Base):
class WorkflowDraftVariableFile(TypeBase):
"""Stores metadata about files associated with large workflow draft variables.
This model acts as an intermediary between WorkflowDraftVariable and UploadFile,
@ -1906,18 +1908,7 @@ class WorkflowDraftVariableFile(Base):
__tablename__ = "workflow_draft_variable_files"
# Primary key
id: Mapped[str] = mapped_column(
StringUUID,
primary_key=True,
default=lambda: str(uuidv7()),
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=naive_utc_now,
server_default=func.current_timestamp(),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default_factory=lambda: str(uuidv7()), init=False)
tenant_id: Mapped[str] = mapped_column(
StringUUID,
@ -1969,15 +1960,23 @@ class WorkflowDraftVariableFile(Base):
nullable=False,
)
# Relationship to UploadFile
# Rows are created with `upload_file_id`; callers should load this relationship explicitly when needed.
upload_file: Mapped["UploadFile"] = orm.relationship(
UploadFile,
foreign_keys=[upload_file_id],
lazy="raise",
init=False,
uselist=False,
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default_factory=naive_utc_now,
server_default=func.current_timestamp(),
)
def is_system_variable_editable(name: str) -> bool:
return name in _EDITABLE_SYSTEM_VARIABLE

View File

@ -246,8 +246,18 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
if not cluster_info.qdrant_endpoint:
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
item
) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
if cluster_info.qdrant_endpoint:
cluster_info.status = TidbAuthBindingStatus.ACTIVE
else:
logger.warning(
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
item["clusterId"],
)
db.session.add(cluster_info)
db.session.commit()
else:

View File

@ -1,8 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
from models.enums import TidbAuthBindingStatus
class TestExtractQdrantEndpoint:
"""Unit tests for TidbService.extract_qdrant_endpoint."""
@ -216,3 +219,86 @@ class TestBatchCreateEdgeCases:
private_key="priv",
region="us-east-1",
)
class TestBatchUpdateTidbServerlessClusterStatus:
"""Verify that status updates only expose clusters after qdrant endpoint is ready."""
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {
"clusters": [
{
"clusterId": "c-1",
"state": "ACTIVE",
"userPrefix": "pfx",
"endpoints": {"public": {"host": "gw.tidbcloud.com"}},
}
]
},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
    """Keep the binding in CREATING when no qdrant endpoint can be resolved yet."""
    # Binding starts in CREATING state with no endpoint.
    pending_binding = SimpleNamespace(
        cluster_id="c-1",
        status=TidbAuthBindingStatus.CREATING,
        account="root",
        qdrant_endpoint=None,
    )
    # Cluster reports ACTIVE but exposes no endpoints, and the fetch fallback
    # (mocked above) also returns None — the endpoint is simply not ready.
    batch_body = {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]}
    mock_http.get.return_value = MagicMock(status_code=200, json=lambda: batch_body)

    TidbService.batch_update_tidb_serverless_cluster_status([pending_binding], "proj", "url", "iam", "pub", "priv")

    # The account prefix is still applied, but status stays CREATING and no
    # endpoint is recorded; the (unchanged-status) binding is still persisted.
    assert pending_binding.account == "pfx.root"
    assert pending_binding.qdrant_endpoint is None
    assert pending_binding.status == TidbAuthBindingStatus.CREATING
    mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
    mock_db.session.add.assert_called_once_with(pending_binding)
    mock_db.session.commit.assert_called_once()

View File

@ -1,12 +1,12 @@
[project]
name = "dify-api"
version = "1.13.3"
version = "1.14.0"
requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
"boto3>=1.42.96",
"boto3>=1.43.3",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
@ -14,7 +14,7 @@ dependencies = [
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
"gmpy2>=2.3.0",
"google-api-python-client>=2.194.0",
"google-api-python-client>=2.195.0",
"gunicorn>=25.3.0",
"psycogreen>=1.0.2",
"psycopg2-binary>=2.9.12",
@ -31,7 +31,7 @@ dependencies = [
"flask-migrate>=4.1.0,<5.0.0",
"flask-orjson>=2.0.0,<3.0.0",
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.148.1,<2.0.0",
"google-cloud-aiplatform>=1.149.0,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b1,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
@ -127,7 +127,7 @@ dev = [
"testcontainers>=4.14.2",
"types-aiofiles>=25.1.0",
"types-beautifulsoup4>=4.12.0",
"types-cachetools>=6.2.0",
"types-cachetools>=7.0.0.20260503",
"types-colorama>=0.4.15",
"types-defusedxml>=0.7.0",
"types-deprecated>=1.3.1",
@ -135,7 +135,7 @@ dev = [
"types-flask-cors>=6.0.0",
"types-flask-migrate>=4.1.0",
"types-gevent>=26.4.0",
"types-greenlet>=3.4.0",
"types-greenlet>=3.5.0.20260428",
"types-html5lib>=1.1.11",
"types-markdown>=3.10.2",
"types-oauthlib>=3.3.0",
@ -143,7 +143,7 @@ dev = [
"types-olefile>=0.47.0",
"types-openpyxl>=3.1.5",
"types-pexpect>=4.9.0",
"types-protobuf>=7.34.1",
"types-protobuf>=7.34.1.20260503",
"types-psutil>=7.2.2",
"types-psycopg2>=2.9.21.20260422",
"types-pygments>=2.20.0",
@ -158,11 +158,11 @@ dev = [
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.42.96",
"boto3-stubs>=1.43.2",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.152.3",
"hypothesis>=6.152.4",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=2.0.0.20260408",
"types_cffi>=2.0.0.20260429",
"types_setuptools>=82.0.0.20260408",
"pandas-stubs>=3.0.0",
"scipy-stubs>=1.17.1.4",
@ -174,7 +174,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.62.0",
"pyrefly>=0.64.0",
"xinference-client>=2.7.0",
]
@ -184,7 +184,7 @@ dev = [
############################################################
storage = [
"azure-storage-blob>=12.28.0",
"bce-python-sdk>=0.9.70",
"bce-python-sdk>=0.9.71",
"cos-python-sdk-v5>=1.9.42",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",

View File

@ -3,6 +3,7 @@ from typing import Any, Literal
from pydantic import BaseModel, field_validator
from core.rag.entities import Rule
from core.rag.entities.metadata_entities import MetadataFilteringCondition
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -83,6 +84,7 @@ class RetrievalModel(BaseModel):
score_threshold_enabled: bool
score_threshold: float | None = None
weights: WeightModel | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
class MetaDataConfig(BaseModel):

View File

@ -1,9 +1,11 @@
import uuid
from typing import cast
import sqlalchemy as sa
from flask_login import current_user
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy import delete, func, select
from sqlalchemy.engine import CursorResult
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
@ -29,7 +31,7 @@ class TagBindingCreatePayload(BaseModel):
class TagBindingDeletePayload(BaseModel):
tag_id: str
tag_ids: list[str] = Field(min_length=1)
target_id: str
type: TagType
@ -178,13 +180,18 @@ class TagService:
@staticmethod
def delete_tag_binding(payload: TagBindingDeletePayload):
TagService.check_target_exists(payload.type, payload.target_id)
tag_binding = db.session.scalar(
select(TagBinding)
.where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
.limit(1)
result = cast(
CursorResult,
db.session.execute(
delete(TagBinding).where(
TagBinding.target_id == payload.target_id,
TagBinding.tag_id.in_(payload.tag_ids),
TagBinding.tenant_id == current_user.current_tenant_id,
)
),
)
if tag_binding:
db.session.delete(tag_binding)
if result.rowcount:
db.session.commit()
@staticmethod

View File

@ -1083,10 +1083,9 @@ class DraftVariableSaver:
mimetype=content_type,
user=self._user,
)
assert self._user.current_tenant_id
# Create WorkflowDraftVariableFile record
variable_file = WorkflowDraftVariableFile(
id=uuidv7(),
upload_file_id=upload_file.id,
size=original_size,
length=original_length,
@ -1095,6 +1094,7 @@ class DraftVariableSaver:
tenant_id=self._user.current_tenant_id,
user_id=self._user.id,
)
variable_file.id = str(uuidv7())
engine = bind = self._session.get_bind()
assert isinstance(engine, Engine)
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:

View File

@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask:
@pytest.fixture
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]:
"""
Request context fixture for containerized Flask application.
@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None,
@pytest.fixture
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]:
"""
Test client fixture for containerized Flask application.
@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli
@pytest.fixture
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]:
"""
Database session fixture for containerized testing.

View File

@ -7,6 +7,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
@ -69,7 +70,7 @@ def _unwrap(func):
class TestCompletionEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_completion_create_payload(self):
@ -86,7 +87,7 @@ class TestCompletionEndpoints:
)
assert payload.query == "hi"
def test_completion_api_success(self, app, monkeypatch):
def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -116,7 +117,7 @@ class TestCompletionEndpoints:
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -142,7 +143,7 @@ class TestCompletionEndpoints:
with pytest.raises(NotFound):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -166,7 +167,7 @@ class TestCompletionEndpoints:
with pytest.raises(completion_module.ProviderNotInitializeError):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app, monkeypatch):
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
@ -193,10 +194,10 @@ class TestCompletionEndpoints:
class TestAppEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppApi()
method = _unwrap(api.put)
payload = {
@ -234,7 +235,7 @@ class TestAppEndpoints:
}
)
def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppIconApi()
method = _unwrap(api.post)
payload = {
@ -266,7 +267,7 @@ class TestAppEndpoints:
class TestOpsTraceEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_ops_trace_query_basic(self):
@ -277,7 +278,7 @@ class TestOpsTraceEndpoints:
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
assert payload.tracing_config["api_key"] == "k"
def test_trace_app_config_get_empty(self, app, monkeypatch):
def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
@ -292,7 +293,7 @@ class TestOpsTraceEndpoints:
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app, monkeypatch):
def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
@ -309,7 +310,7 @@ class TestOpsTraceEndpoints:
with pytest.raises(BadRequest):
method(app_id="app-1")
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
@ -326,7 +327,7 @@ class TestOpsTraceEndpoints:
class TestSiteEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_site_response_structure(self):
@ -337,7 +338,7 @@ class TestSiteEndpoints:
payload = AppSiteUpdatePayload(default_language="en-US")
assert payload.default_language == "en-US"
def test_app_site_update_post(self, app, monkeypatch):
def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSite()
method = _unwrap(api.post)
@ -375,7 +376,7 @@ class TestSiteEndpoints:
assert isinstance(result, dict)
assert result["title"] == "My Site"
def test_app_site_access_token_reset(self, app, monkeypatch):
def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
@ -427,7 +428,7 @@ class TestWorkflowEndpoints:
class TestWorkflowAppLogEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_app_log_query(self):
@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints:
query = WorkflowAppLogQuery(detail="true")
assert query.detail is True
def test_workflow_app_log_api_get(self, app, monkeypatch):
def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints:
class TestWorkflowDraftVariableEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_variable_creation(self):
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
def test_workflow_variable_collection_get(self, app, monkeypatch):
def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints:
class TestWorkflowStatisticEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_workflow_statistic_time_range(self):
@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints:
assert query.start is None
assert query.end is None
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints:
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints:
class TestWorkflowTriggerEndpoints:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_webhook_trigger_payload(self):
@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints:
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
assert enable_payload.enable_trigger is True
def test_webhook_trigger_api_get(self, app, monkeypatch):
def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
class TestAppImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -57,7 +58,7 @@ class TestAppImportApi:
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -75,7 +76,7 @@ class TestAppImportApi:
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -96,7 +97,7 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -121,7 +122,7 @@ class TestAppImportApi:
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
@ -149,10 +150,10 @@ class TestAppImportApi:
class TestAppImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
@ -172,10 +173,10 @@ class TestAppImportConfirmApi:
class TestAppImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
@ -16,7 +17,7 @@ from services.account_service import AccountService
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
return flask_app_with_containers
@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi:
mock_is_freeze,
mock_send_mail,
mock_get_account,
app,
app: Flask,
):
mock_send_mail.return_value = "token-123"
mock_is_freeze.return_value = False
@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi:
mock_revoke,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
@ -120,7 +121,7 @@ class TestEmailRegisterResetApi:
mock_create_account,
mock_login,
mock_reset_login_rate,
app,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()

View File

@ -6,6 +6,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
@ -16,7 +17,7 @@ from services.account_service import AccountService
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
return flask_app_with_containers
@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi:
mock_is_ip_limit,
mock_send_email,
mock_get_account,
app,
app: Flask,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_rate_limit_check.return_value = False
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
@ -123,7 +124,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_get_account,
mock_update_account,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError
class TestGetOAuthProviders:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.mark.parametrize(
@ -65,7 +66,7 @@ class TestOAuthLogin:
return OAuthLogin()
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -89,7 +90,7 @@ class TestOAuthLogin:
mock_redirect,
mock_get_providers,
resource,
app,
app: Flask,
mock_oauth_provider,
invite_token,
expected_token,
@ -130,7 +131,7 @@ class TestOAuthCallback:
return OAuthCallback()
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -164,7 +165,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -217,7 +218,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -261,7 +262,7 @@ class TestOAuthCallback:
mock_tenant_service,
mock_account_service,
resource,
app,
app: Flask,
oauth_setup,
account_status,
expected_redirect,
@ -300,7 +301,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -336,7 +337,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
@ -394,7 +395,7 @@ class TestOAuthCallback:
class TestAccountGeneration:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -465,7 +466,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
allow_register,
@ -504,7 +505,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
@ -529,7 +530,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
):

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.error import (
EmailCodeError,
@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -46,7 +47,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
):
# Arrange
@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email.assert_called_once()
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask):
"""
Test password reset email blocked by IP rate limit.
@ -104,7 +105,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@ -153,7 +154,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
"""
Test successful verification code validation.
@ -200,7 +201,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi:
mock_reset_rate_limit.assert_called_once_with("user@example.com")
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask):
"""
Test code verification blocked by rate limit.
@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with invalid token.
@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with mismatched email.
@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi:
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask):
"""
Test code verification with incorrect code.
@ -321,7 +322,7 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -344,7 +345,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
app,
app: Flask,
mock_account,
):
"""
@ -375,7 +376,7 @@ class TestForgotPasswordResetApi:
mock_revoke_token.assert_called_once_with("valid_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_mismatch(self, mock_get_data, app):
def test_reset_password_mismatch(self, mock_get_data, app: Flask):
"""
Test password reset with mismatched passwords.
@ -397,7 +398,7 @@ class TestForgotPasswordResetApi:
api.post()
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_invalid_token(self, mock_get_data, app):
def test_reset_password_invalid_token(self, mock_get_data, app: Flask):
"""
Test password reset with invalid token.
@ -418,7 +419,7 @@ class TestForgotPasswordResetApi:
api.post()
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_password_wrong_phase(self, mock_get_data, app):
def test_reset_password_wrong_phase(self, mock_get_data, app: Flask):
"""
Test password reset with token not in reset phase.
@ -442,7 +443,7 @@ class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask):
"""
Test password reset for non-existent account.

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -26,10 +27,10 @@ def unwrap(func):
class TestPipelineTemplateListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@ -50,10 +51,10 @@ class TestPipelineTemplateListApi:
class TestPipelineTemplateDetailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -74,7 +75,7 @@ class TestPipelineTemplateDetailApi:
assert status == 200
assert response == template
def test_get_returns_404_when_template_not_found(self, app):
def test_get_returns_404_when_template_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -93,7 +94,7 @@ class TestPipelineTemplateDetailApi:
assert status == 404
assert "error" in response
def test_get_returns_404_for_customized_type_not_found(self, app):
def test_get_returns_404_for_customized_type_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -115,10 +116,10 @@ class TestPipelineTemplateDetailApi:
class TestCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@ -140,7 +141,7 @@ class TestCustomizedPipelineTemplateApi:
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
@ -155,7 +156,7 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app, db_session_with_containers: Session):
def test_post_success(self, app: Flask, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -182,7 +183,7 @@ class TestCustomizedPipelineTemplateApi:
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
def test_post_template_not_found(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -193,10 +194,10 @@ class TestCustomizedPipelineTemplateApi:
class TestPublishCustomizedPipelineTemplateApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
import services
@ -24,13 +25,13 @@ def unwrap(func):
class TestCreateRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _valid_payload(self):
return {"yaml_content": "name: test"}
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi:
assert status == 201
assert response == import_info
def test_post_forbidden_non_editor(self, app):
def test_post_forbidden_non_editor(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi:
with pytest.raises(Forbidden):
method(api)
def test_post_dataset_name_duplicate(self, app):
def test_post_dataset_name_duplicate(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi:
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload(self, app):
def test_post_invalid_payload(self, app: Flask):
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi:
class TestCreateEmptyRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi:
assert status == 201
assert response == {"id": "ds-1"}
def test_post_forbidden_non_editor(self, app):
def test_post_forbidden_non_editor(self, app: Flask):
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
@ -25,7 +26,7 @@ def unwrap(func):
class TestRagPipelineImportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _payload(self, mode="create"):
@ -35,7 +36,7 @@ class TestRagPipelineImportApi:
"name": "Test",
}
def test_post_success_200(self, app):
def test_post_success_200(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -65,7 +66,7 @@ class TestRagPipelineImportApi:
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
def test_post_failed_400(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -95,7 +96,7 @@ class TestRagPipelineImportApi:
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
def test_post_pending_202(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -128,10 +129,10 @@ class TestRagPipelineImportApi:
class TestRagPipelineImportConfirmApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_confirm_success(self, app):
def test_confirm_success(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -159,7 +160,7 @@ class TestRagPipelineImportConfirmApi:
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
def test_confirm_failed(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -190,10 +191,10 @@ class TestRagPipelineImportConfirmApi:
class TestRagPipelineImportCheckDependenciesApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@ -219,10 +220,10 @@ class TestRagPipelineImportCheckDependenciesApi:
class TestRagPipelineExportApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
def test_get_with_include_secret(self, app: Flask):
api = RagPipelineExportApi()
method = unwrap(api.get)

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
@ -45,10 +46,10 @@ def unwrap(func):
class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_draft_success(self, app):
def test_get_draft_success(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@ -68,7 +69,7 @@ class TestDraftWorkflowApi:
result = method(api, pipeline)
assert result == workflow
def test_get_draft_not_exist(self, app):
def test_get_draft_not_exist(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.get)
@ -86,7 +87,7 @@ class TestDraftWorkflowApi:
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_sync_hash_not_match(self, app):
def test_sync_hash_not_match(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.post)
@ -111,7 +112,7 @@ class TestDraftWorkflowApi:
with pytest.raises(DraftWorkflowNotSync):
method(api, pipeline)
def test_sync_invalid_text_plain(self, app):
def test_sync_invalid_text_plain(self, app: Flask):
api = DraftRagPipelineApi()
method = unwrap(api.post)
@ -128,7 +129,7 @@ class TestDraftWorkflowApi:
response, status = method(api, pipeline)
assert status == 400
def test_restore_published_workflow_to_draft_success(self, app):
def test_restore_published_workflow_to_draft_success(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -155,7 +156,7 @@ class TestDraftWorkflowApi:
assert result["result"] == "success"
assert result["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(self, app):
def test_restore_published_workflow_to_draft_not_found(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -179,7 +180,7 @@ class TestDraftWorkflowApi:
with pytest.raises(NotFound):
method(api, pipeline, "published-workflow")
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app):
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask):
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
@ -211,10 +212,10 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_iteration_node_success(self, app):
def test_iteration_node_success(self, app: Flask):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@ -240,7 +241,7 @@ class TestDraftRunNodes:
result = method(api, pipeline, "node")
assert result == {"ok": True}
def test_iteration_node_conversation_not_exists(self, app):
def test_iteration_node_conversation_not_exists(self, app: Flask):
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
@ -262,7 +263,7 @@ class TestDraftRunNodes:
with pytest.raises(NotFound):
method(api, pipeline, "node")
def test_loop_node_success(self, app):
def test_loop_node_success(self, app: Flask):
api = RagPipelineDraftRunLoopNodeApi()
method = unwrap(api.post)
@ -290,10 +291,10 @@ class TestDraftRunNodes:
class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_draft_run_success(self, app):
def test_draft_run_success(self, app: Flask):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@ -325,7 +326,7 @@ class TestPipelineRunApis:
):
assert method(api, pipeline) == {"ok": True}
def test_draft_run_rate_limit(self, app):
def test_draft_run_rate_limit(self, app: Flask):
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
@ -356,10 +357,10 @@ class TestPipelineRunApis:
class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_execution_not_found(self, app):
def test_execution_not_found(self, app: Flask):
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
@ -387,10 +388,10 @@ class TestDraftNodeRun:
class TestPublishedPipelineApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()
@ -436,10 +437,10 @@ class TestPublishedPipelineApis:
class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_task_stop(self, app):
def test_task_stop(self, app: Flask):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
@ -460,7 +461,7 @@ class TestMiscApis:
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app):
def test_transform_forbidden(self, app: Flask):
api = RagPipelineTransformApi()
method = unwrap(api.post)
@ -476,7 +477,7 @@ class TestMiscApis:
with pytest.raises(Forbidden):
method(api, "ds1")
def test_recommended_plugins(self, app):
def test_recommended_plugins(self, app: Flask):
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
@ -496,10 +497,10 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_published_run_success(self, app):
def test_published_run_success(self, app: Flask):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi:
result = method(api, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app):
def test_published_run_rate_limit(self, app: Flask):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi:
class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_block_config_success(self, app):
def test_get_block_config_success(self, app: Flask):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi:
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app):
def test_get_block_config_invalid_json(self, app: Flask):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_published_workflows_success(self, app):
def test_get_published_workflows_success(self, app: Flask):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi:
assert result["items"] == [{"id": "w1"}]
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app):
def test_get_published_workflows_forbidden(self, app: Flask):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi:
class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@ -682,7 +683,7 @@ class TestRagPipelineByIdApi:
assert result == workflow
def test_patch_no_fields(self, app):
def test_patch_no_fields(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
@ -700,7 +701,7 @@ class TestRagPipelineByIdApi:
result, status = method(api, pipeline, "w1")
assert status == 400
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.delete)
@ -720,7 +721,7 @@ class TestRagPipelineByIdApi:
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app):
def test_delete_active_workflow_rejected(self, app: Flask):
api = RagPipelineByIdApi()
method = unwrap(api.delete)
@ -733,10 +734,10 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_last_run_success(self, app):
def test_last_run_success(self, app: Flask):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi:
result = method(api, pipeline, "node1")
assert result == node_exec
def test_last_run_not_found(self, app):
def test_last_run_not_found(self, app: Flask):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_set_datasource_variables_success(self, app):
def test_set_datasource_variables_success(self, app: Flask):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console.datasets import data_source
@ -51,10 +52,10 @@ def mock_engine():
class TestDataSourceApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -78,7 +79,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
def test_get_no_bindings(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -94,7 +95,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -115,7 +116,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -136,7 +137,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -151,7 +152,7 @@ class TestDataSourceApi:
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -168,7 +169,7 @@ class TestDataSourceApi:
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -188,10 +189,10 @@ class TestDataSourceApi:
class TestDataSourceNotionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
def test_get_credential_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -205,7 +206,7 @@ class TestDataSourceNotionListApi:
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -246,7 +247,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -299,7 +300,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -323,10 +324,10 @@ class TestDataSourceNotionListApi:
class TestDataSourceNotionApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
def test_get_preview_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
@ -347,7 +348,7 @@ class TestDataSourceNotionApi:
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
@ -381,10 +382,10 @@ class TestDataSourceNotionApi:
class TestDataSourceNotionDatasetSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -407,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi:
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -424,10 +425,10 @@ class TestDataSourceNotionDatasetSyncApi:
class TestDataSourceNotionDocumentSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
@ -450,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi:
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
def test_get_document_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.conversation as conversation_module
@ -53,10 +54,10 @@ def user():
class TestConversationListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -81,7 +82,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app, chat_app, user):
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -97,7 +98,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app, non_chat_app):
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -108,10 +109,10 @@ class TestConversationListApi:
class TestConversationApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -129,7 +130,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app, chat_app, user):
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -145,7 +146,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app, non_chat_app):
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -156,10 +157,10 @@ class TestConversationApi:
class TestConversationRenameApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -178,7 +179,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app, chat_app, user):
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -197,10 +198,10 @@ class TestConversationRenameApi:
class TestConversationPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@ -219,10 +220,10 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@ -6,6 +6,7 @@ import json
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import (
@ -60,7 +61,7 @@ def _mock_user_tenant():
@pytest.fixture
def client(flask_app_with_containers):
def client(flask_app_with_containers: Flask):
return flask_app_with_containers.test_client()
@ -147,10 +148,10 @@ class TestUtils:
class TestToolProviderListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ToolProviderListApi()
method = unwrap(api.get)
@ -170,10 +171,10 @@ class TestToolProviderListApi:
class TestBuiltinProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_tools(self, app):
def test_list_tools(self, app: Flask):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
@ -190,7 +191,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == [{"a": 1}]
def test_info(self, app):
def test_info(self, app: Flask):
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
@ -207,7 +208,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"x": 1}
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
@ -224,7 +225,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["result"] == "success"
def test_add_invalid_type(self, app):
def test_add_invalid_type(self, app: Flask):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
@ -238,7 +239,7 @@ class TestBuiltinProviderApis:
with pytest.raises(ValueError):
method(api, "provider")
def test_add_success(self, app):
def test_add_success(self, app: Flask):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
@ -257,7 +258,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["id"] == 1
def test_update(self, app):
def test_update(self, app: Flask):
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
@ -276,7 +277,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["ok"]
def test_get_credentials(self, app):
def test_get_credentials(self, app: Flask):
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
@ -293,7 +294,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"k": "v"}
def test_icon(self, app):
def test_icon(self, app: Flask):
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
@ -307,7 +308,7 @@ class TestBuiltinProviderApis:
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app):
def test_credentials_schema(self, app: Flask):
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
@ -324,7 +325,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app):
def test_set_default_credential(self, app: Flask):
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
@ -341,7 +342,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider")["ok"]
def test_get_credential_info(self, app):
def test_get_credential_info(self, app: Flask):
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
@ -358,7 +359,7 @@ class TestBuiltinProviderApis:
):
assert method(api, "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app):
def test_get_oauth_client_schema(self, app: Flask):
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
@ -378,10 +379,10 @@ class TestBuiltinProviderApis:
class TestApiProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_add(self, app):
def test_add(self, app: Flask):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
@ -406,7 +407,7 @@ class TestApiProviderApis:
):
assert method(api)["id"] == 1
def test_remote_schema(self, app):
def test_remote_schema(self, app: Flask):
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
@ -423,7 +424,7 @@ class TestApiProviderApis:
):
assert method(api)["schema"] == "x"
def test_list_tools(self, app):
def test_list_tools(self, app: Flask):
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
@ -440,7 +441,7 @@ class TestApiProviderApis:
):
assert method(api) == [{"tool": 1}]
def test_update(self, app):
def test_update(self, app: Flask):
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
@ -468,7 +469,7 @@ class TestApiProviderApis:
):
assert method(api)["ok"]
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
@ -485,7 +486,7 @@ class TestApiProviderApis:
):
assert method(api)["result"] == "success"
def test_get(self, app):
def test_get(self, app: Flask):
api = ToolApiProviderGetApi()
method = unwrap(api.get)
@ -505,10 +506,10 @@ class TestApiProviderApis:
class TestWorkflowApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create(self, app):
def test_create(self, app: Flask):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
@ -534,7 +535,7 @@ class TestWorkflowApis:
):
assert method(api)["id"] == 1
def test_update_invalid(self, app):
def test_update_invalid(self, app: Flask):
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
@ -560,7 +561,7 @@ class TestWorkflowApis:
result = method(api)
assert result["ok"]
def test_delete(self, app):
def test_delete(self, app: Flask):
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
@ -577,7 +578,7 @@ class TestWorkflowApis:
):
assert method(api)["ok"]
def test_get_error(self, app):
def test_get_error(self, app: Flask):
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
@ -594,10 +595,10 @@ class TestWorkflowApis:
class TestLists:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_builtin_list(self, app):
def test_builtin_list(self, app: Flask):
api = ToolBuiltinListApi()
method = unwrap(api.get)
@ -617,7 +618,7 @@ class TestLists:
):
assert method(api) == [{"x": 1}]
def test_api_list(self, app):
def test_api_list(self, app: Flask):
api = ToolApiListApi()
method = unwrap(api.get)
@ -637,7 +638,7 @@ class TestLists:
):
assert method(api) == [{"x": 1}]
def test_workflow_list(self, app):
def test_workflow_list(self, app: Flask):
api = ToolWorkflowListApi()
method = unwrap(api.get)
@ -660,10 +661,10 @@ class TestLists:
class TestLabels:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_labels(self, app):
def test_labels(self, app: Flask):
api = ToolLabelsApi()
method = unwrap(api.get)
@ -679,10 +680,10 @@ class TestLabels:
class TestOAuth:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_no_client(self, app):
def test_oauth_no_client(self, app: Flask):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
@ -700,7 +701,7 @@ class TestOAuth:
with pytest.raises(Forbidden):
method(api, "provider")
def test_oauth_callback_no_cookie(self, app):
def test_oauth_callback_no_cookie(self, app: Flask):
api = ToolOAuthCallback()
method = unwrap(api.get)
@ -711,10 +712,10 @@ class TestOAuth:
class TestOAuthCustomClient:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_save_custom_client(self, app):
def test_save_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.post)
@ -731,7 +732,7 @@ class TestOAuthCustomClient:
):
assert method(api, "provider")["ok"]
def test_get_custom_client(self, app):
def test_get_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.get)
@ -748,7 +749,7 @@ class TestOAuthCustomClient:
):
assert method(api, "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app):
def test_delete_custom_client(self, app: Flask):
api = ToolOAuthCustomClient()
method = unwrap(api.delete)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console.workspace.trigger_providers import (
@ -45,10 +46,10 @@ def mock_user():
class TestTriggerProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = TriggerProviderIconApi()
method = unwrap(api.get)
@ -62,7 +63,7 @@ class TestTriggerProviderApis:
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
def test_list_providers(self, app: Flask):
api = TriggerProviderListApi()
method = unwrap(api.get)
@ -76,7 +77,7 @@ class TestTriggerProviderApis:
):
assert method(api) == []
def test_provider_info(self, app):
def test_provider_info(self, app: Flask):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
@ -93,10 +94,10 @@ class TestTriggerProviderApis:
class TestTriggerSubscriptionListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -110,7 +111,7 @@ class TestTriggerSubscriptionListApi:
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
def test_list_invalid_provider(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -128,10 +129,10 @@ class TestTriggerSubscriptionListApi:
class TestTriggerSubscriptionBuilderApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create_builder(self, app):
def test_create_builder(self, app: Flask):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
@ -146,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis:
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
def test_get_builder(self, app: Flask):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
@ -159,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
def test_verify_builder(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -173,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
def test_verify_builder_error(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -188,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis:
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
def test_update_builder(self, app: Flask):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
@ -202,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
def test_logs(self, app: Flask):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
@ -219,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
def test_build(self, app: Flask):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
@ -236,10 +237,10 @@ class TestTriggerSubscriptionBuilderApis:
class TestTriggerSubscriptionCrud:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_update_rename_only(self, app):
def test_update_rename_only(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -258,7 +259,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
def test_update_not_found(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -273,7 +274,7 @@ class TestTriggerSubscriptionCrud:
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
def test_update_rebuild(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -296,7 +297,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
def test_delete_subscription(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -319,7 +320,7 @@ class TestTriggerSubscriptionCrud:
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
def test_delete_subscription_value_error(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -342,10 +343,10 @@ class TestTriggerSubscriptionCrud:
class TestTriggerOAuthApis:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
def test_oauth_authorize_success(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -372,7 +373,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
def test_oauth_authorize_no_client(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -387,7 +388,7 @@ class TestTriggerOAuthApis:
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
def test_oauth_callback_forbidden(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -395,7 +396,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
def test_oauth_callback_success(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -425,7 +426,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
def test_oauth_callback_no_oauth_client(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -449,7 +450,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
def test_oauth_callback_empty_credentials(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -480,10 +481,10 @@ class TestTriggerOAuthApis:
class TestTriggerOAuthClientManageApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_client(self, app):
def test_get_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
@ -510,7 +511,7 @@ class TestTriggerOAuthClientManageApi:
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
def test_post_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -524,7 +525,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
def test_delete_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
@ -538,7 +539,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
def test_oauth_client_post_value_error(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -556,10 +557,10 @@ class TestTriggerOAuthClientManageApi:
class TestTriggerSubscriptionVerifyApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_verify_success(self, app):
def test_verify_success(self, app: Flask):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)

View File

@ -18,6 +18,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
@ -217,10 +218,20 @@ class TestTagUnbindingPayload:
"""Test suite for TagUnbindingPayload Pydantic model."""
def test_payload_with_valid_data(self):
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
assert payload.tag_id == "tag_123"
payload = TagUnbindingPayload(tag_ids=["tag_123"], target_id="dataset_456")
assert payload.tag_ids == ["tag_123"]
assert payload.target_id == "dataset_456"
def test_payload_normalizes_legacy_tag_id(self):
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
assert payload.tag_ids == ["tag_123"]
assert payload.target_id == "dataset_456"
def test_payload_rejects_empty_tag_ids(self):
with pytest.raises(ValueError) as exc_info:
TagUnbindingPayload(tag_ids=[], target_id="dataset_456")
assert "Tag IDs is required" in str(exc_info.value)
# ---------------------------------------------------------------------------
# Helpers
@ -236,7 +247,7 @@ def _unwrap(method):
@pytest.fixture
def app(flask_app_with_containers):
def app(flask_app_with_containers: Flask):
# Uses the full containerised app so that Flask config, extensions, and
# blueprint registrations match production. Most tests mock the service
# layer to isolate controller logic; a few (e.g. test_list_tags_from_db)
@ -280,7 +291,7 @@ class TestDatasetListApiGet:
mock_current_user,
mock_provider_mgr,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -315,7 +326,7 @@ class TestDatasetListApiPost:
mock_dataset_svc,
mock_current_user,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -341,7 +352,7 @@ class TestDatasetListApiPost:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -379,7 +390,7 @@ class TestDatasetApiGet:
mock_provider_mgr,
mock_marshal,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -429,7 +440,7 @@ class TestDatasetApiGet:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -457,7 +468,7 @@ class TestDatasetApiDelete:
mock_dataset_svc,
mock_current_user,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -479,7 +490,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -500,7 +511,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -532,7 +543,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -563,7 +574,7 @@ class TestDocumentStatusApiPatch:
def test_batch_update_status_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -592,7 +603,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -625,7 +636,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -658,7 +669,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -698,7 +709,7 @@ class TestDatasetTagsApiGet:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -720,7 +731,7 @@ class TestDatasetTagsApiGet:
def test_list_tags_from_db(
self,
mock_current_user,
app,
app: Flask,
db_session_with_containers: Session,
):
"""Integration test: creates real Tag rows and retrieves them
@ -763,7 +774,7 @@ class TestDatasetTagsApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -786,7 +797,7 @@ class TestDatasetTagsApiPost:
mock_tag_svc.save_tags.assert_called_once()
@patch("controllers.service_api.dataset.dataset.current_user")
def test_create_tag_forbidden(self, mock_current_user, app):
def test_create_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -815,7 +826,7 @@ class TestDatasetTagsApiPatch:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -841,7 +852,7 @@ class TestDatasetTagsApiPatch:
mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_update_tag_forbidden(self, mock_current_user, app):
def test_update_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -869,7 +880,7 @@ class TestDatasetTagsApiDelete:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -894,7 +905,7 @@ class TestDatasetTagsApiDelete:
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app):
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
user_obj = Mock(spec=Account)
@ -922,7 +933,7 @@ class TestDatasetTagsBindingStatusApi:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi
@ -952,7 +963,7 @@ class TestDatasetTagBindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
@ -977,7 +988,7 @@ class TestDatasetTagBindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_bind_tags_forbidden(self, mock_current_user, app):
def test_bind_tags_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
mock_current_user.__class__ = Account
@ -1003,7 +1014,37 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account
mock_current_user.has_edit_permission = True
mock_current_user.is_dataset_editor = True
mock_tag_svc.delete_tag_binding.return_value = None
with app.test_request_context(
"/datasets/tags/unbinding",
method="POST",
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
):
api = DatasetTagUnbindingApi()
result = api.post(_=None)
assert result == ("", 204)
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
)
@patch("controllers.service_api.dataset.dataset.TagService")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_legacy_tag_id_success(
self,
mock_current_user,
mock_tag_svc,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1024,11 +1065,11 @@ class TestDatasetTagUnbindingApiPost:
from services.tag_service import TagBindingDeletePayload
mock_tag_svc.delete_tag_binding.assert_called_once_with(
TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge")
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_tag_forbidden(self, mock_current_user, app):
def test_unbind_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account
@ -1038,7 +1079,7 @@ class TestDatasetTagUnbindingApiPost:
with app.test_request_context(
"/datasets/tags/unbinding",
method="POST",
json={"tag_id": "tag-1", "target_id": "ds-1"},
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
):
api = DatasetTagUnbindingApi()
with pytest.raises(Forbidden):

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.web.conversation import (
@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace:
class TestConversationListApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context("/conversations"):
with pytest.raises(NotChatAppError):
ConversationListApi().get(_completion_app(), _end_user())
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
conv_id = str(uuid4())
conv = SimpleNamespace(
id=conv_id,
@ -65,16 +66,16 @@ class TestConversationListApi:
class TestConversationApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}"):
with pytest.raises(NotChatAppError):
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.delete")
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
@ -83,7 +84,7 @@ class TestConversationApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}"):
with pytest.raises(NotFound, match="Conversation Not Exists"):
@ -92,17 +93,17 @@ class TestConversationApi:
class TestConversationRenameApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
with pytest.raises(NotChatAppError):
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.ConversationService.rename")
@patch("controllers.web.conversation.web_ns")
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "New Name", "auto_generate": False}
conv = SimpleNamespace(
@ -126,7 +127,7 @@ class TestConversationRenameApi:
side_effect=ConversationNotExistsError(),
)
@patch("controllers.web.conversation.web_ns")
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
c_id = uuid4()
mock_ns.payload = {"name": "X", "auto_generate": False}
@ -137,16 +138,16 @@ class TestConversationRenameApi:
class TestConversationPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.pin")
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
@ -154,7 +155,7 @@ class TestConversationPinApi:
assert result["result"] == "success"
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
with pytest.raises(NotFound):
@ -163,16 +164,16 @@ class TestConversationPinApi:
class TestConversationUnPinApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_non_chat_mode_raises(self, app) -> None:
def test_non_chat_mode_raises(self, app: Flask) -> None:
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
with pytest.raises(NotChatAppError):
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
@patch("controllers.web.conversation.WebConversationService.unpin")
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
c_id = uuid4()
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)

View File

@ -7,6 +7,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.web.forgot_password import (
ForgotPasswordCheckApi,
@ -29,7 +30,7 @@ def _patch_wraps():
class TestForgotPasswordSendEmailApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi:
mock_rate_limit,
mock_get_account,
mock_send_mail,
app,
app: Flask,
):
mock_account = MagicMock()
mock_get_account.return_value = mock_account
@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi:
class TestForgotPasswordCheckApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
@ -157,7 +158,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_get_account,
mock_update_account,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
mock_account = MagicMock()
@ -194,7 +195,7 @@ class TestForgotPasswordResetApi:
mock_db,
mock_token_bytes,
mock_hash_password,
app,
app: Flask,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
account = MagicMock()

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
@ -182,7 +183,7 @@ class TestValidateUserAccessibility:
class TestDecodeJwtToken:
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
@ -239,7 +240,7 @@ class TestDecodeJwtToken:
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
@ -299,7 +300,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
@ -324,7 +325,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
@ -350,7 +351,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)

View File

@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers:
return WorkflowRunService(engine)
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers:
generate_entity=entity,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session):
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
def test_state_persistence_and_retrieval(self, db_session_with_containers):
def test_state_persistence_and_retrieval(self, db_session_with_containers: Session):
"""Test that pause state can be persisted and retrieved correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert retrieved_state["node_run_steps"] == 10
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_database_transaction_handling(self, db_session_with_containers):
def test_database_transaction_handling(self, db_session_with_containers: Session):
"""Test that database transactions are handled correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert pause_model.resumed_at is None
assert pause_model.state_object_key != ""
def test_file_storage_integration(self, db_session_with_containers):
def test_file_storage_integration(self, db_session_with_containers: Session):
"""Test integration with file storage system."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers:
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
def test_workflow_with_different_creators(self, db_session_with_containers):
def test_workflow_with_different_creators(self, db_session_with_containers: Session):
"""Test pause state with workflows created by different users."""
# Arrange - Create workflow with different creator
different_user_id = str(uuid.uuid4())
@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers:
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session):
"""Test that layer ignores non-pause events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers:
).all()
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):
def test_layer_requires_initialization(self, db_session_with_containers: Session):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()

View File

@ -15,11 +15,14 @@ from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
TenantAndAccount = tuple[Tenant, Account]
@dataclass
class TestTask:
@ -40,7 +43,7 @@ class TestTenantIsolatedTaskQueueIntegration:
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
"""Create test tenant and account for testing."""
# Create account
account = Account(
@ -73,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
def test_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
@ -94,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
def test_tenant_isolation(
self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
@ -114,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
@ -176,7 +181,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert len(remaining_tasks) == 2
assert remaining_tasks == ["task4", "task5"]
def test_push_and_pull_complex_objects(self, test_queue, fake):
def test_push_and_pull_complex_objects(self, test_queue, fake: Faker):
"""Test pushing and pulling complex object tasks."""
# Create complex task objects as dictionaries (not dataclass instances)
tasks = [
@ -218,7 +223,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert pulled_task["data"] == original_task["data"]
assert pulled_task["metadata"] == original_task["metadata"]
def test_mixed_task_types(self, test_queue, fake):
def test_mixed_task_types(self, test_queue, fake: Faker):
"""Test pushing and pulling mixed string and object tasks."""
string_task = "simple_string_task"
object_task = {
@ -267,7 +272,7 @@ class TestTenantIsolatedTaskQueueIntegration:
# Verify task key has expired
assert test_queue.get_task_key() is None
def test_large_task_batch(self, test_queue, fake):
def test_large_task_batch(self, test_queue, fake: Faker):
"""Test handling large batches of tasks."""
# Create large batch of tasks
large_batch = []
@ -292,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
@ -312,7 +317,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert tasks2 == ["task1_queue2", "task2_queue2"]
assert tasks1 != tasks2
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker):
"""Test TaskWrapper serialization and deserialization roundtrip."""
# Create complex nested data
complex_data = {
@ -346,7 +351,7 @@ class TestTenantIsolatedTaskQueueIntegration:
task = test_queue.pull_tasks(1)
assert task[0] == invalid_json_task
def test_real_world_batch_processing_scenario(self, test_queue, fake):
def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker):
"""Test realistic batch processing scenario."""
# Simulate batch processing tasks
batch_tasks = []
@ -403,7 +408,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
"""Create test tenant and account for testing."""
# Create account
account = Account(
@ -435,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test compatibility with legacy queues containing only string data.
@ -465,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test complete migration scenario from legacy to new system.
@ -546,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test error recovery when legacy queue contains malformed data.

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw
class TestGetAvailableDatasetsIntegration:
def test_returns_datasets_with_available_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].name == dataset.name
def test_filters_out_datasets_with_only_archived_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_filters_out_datasets_with_only_disabled_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_filters_out_datasets_with_non_completed_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration:
assert len(result) == 0
def test_includes_external_datasets_without_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that external datasets are returned even with no available documents.
@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].id == dataset.id
assert result[0].provider == "external"
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
# Arrange
fake = Faker()
@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration:
assert result[0].tenant_id == tenant1.id
def test_returns_empty_list_when_no_datasets_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration:
# Assert
assert result == []
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
def test_returns_only_requested_dataset_ids(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration:
class TestKnowledgeRetrievalIntegration:
def test_knowledge_retrieval_with_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration:
assert isinstance(result, list)
def test_knowledge_retrieval_no_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()
@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration:
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
# Arrange
fake = Faker()

View File

@ -1,4 +1,5 @@
import unittest
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import patch
from uuid import uuid4
@ -16,7 +17,7 @@ from models.enums import CreatorUserRole
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
class TestStorageKeyLoader(unittest.TestCase):
class TestStorageKeyLoader:
"""
Integration tests for StorageKeyLoader class.
@ -24,110 +25,82 @@ class TestStorageKeyLoader(unittest.TestCase):
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
"""
def setUp(self):
"""Set up test data before each test method."""
self.session = db.session()
self.tenant_id = str(uuid4())
self.user_id = str(uuid4())
self.conversation_id = str(uuid4())
# Create test data that will be cleaned up after each test
self.test_upload_files = []
self.test_tool_files = []
# Create StorageKeyLoader instance
self.loader = StorageKeyLoader(
self.session,
self.tenant_id,
access_controller=DatabaseFileAccessController(),
)
def tearDown(self):
"""Clean up test data after each test method."""
self.session.rollback()
# ------------------------------------------------------------------
# Per-test helpers (use db_session_with_containers as parameter)
# ------------------------------------------------------------------
@staticmethod
def _create_upload_file(
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
session: Session,
tenant_id: str,
user_id: str,
*,
file_id: str | None = None,
storage_key: str | None = None,
override_tenant_id: str | None = None,
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if storage_key is None:
storage_key = f"test_storage_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
"""Create and flush an UploadFile record for testing."""
upload_file = UploadFile(
tenant_id=tenant_id,
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
storage_type=StorageType.LOCAL,
key=storage_key,
key=storage_key or f"test_storage_key_{uuid4()}",
name="test_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_by=user_id,
created_at=datetime.now(UTC),
used=False,
)
upload_file.id = file_id
self.session.add(upload_file)
self.session.flush()
self.test_upload_files.append(upload_file)
upload_file.id = file_id or str(uuid4())
session.add(upload_file)
session.flush()
return upload_file
@staticmethod
def _create_tool_file(
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
session: Session,
tenant_id: str,
user_id: str,
conversation_id: str,
*,
file_id: str | None = None,
file_key: str | None = None,
override_tenant_id: str | None = None,
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
file_id = str(uuid4())
if file_key is None:
file_key = f"test_file_key_{uuid4()}"
if tenant_id is None:
tenant_id = self.tenant_id
"""Create and flush a ToolFile record for testing."""
tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
user_id=user_id,
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
conversation_id=conversation_id,
file_key=file_key or f"test_file_key_{uuid4()}",
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)
tool_file.id = file_id or str(uuid4())
session.add(tool_file)
session.flush()
return tool_file
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
file_related_id = None
remote_url = None
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
file_related_id = related_id
elif transfer_method == FileTransferMethod.REMOTE_URL:
remote_url = "https://example.com/test_file.txt"
file_related_id = related_id
@staticmethod
def _create_file(
tenant_id: str,
related_id: str,
transfer_method: FileTransferMethod,
*,
override_tenant_id: str | None = None,
) -> File:
"""Build a File value-object for testing."""
remote_url = "https://example.com/test_file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
return File(
file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
file_id=str(uuid4()),
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
related_id=related_id,
remote_url=remote_url,
filename="test_file.txt",
extension=".txt",
@ -136,240 +109,280 @@ class TestStorageKeyLoader(unittest.TestCase):
storage_key="initial_key",
)
def test_load_storage_keys_local_file(self):
# ------------------------------------------------------------------
# Tests
# ------------------------------------------------------------------
def test_load_storage_keys_local_file(self, db_session_with_containers: Session):
"""Test loading storage keys for LOCAL_FILE transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
tenant_id = str(uuid4())
user_id = str(uuid4())
# Load storage keys
self.loader.load_storage_keys([file])
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_remote_url(self):
def test_load_storage_keys_remote_url(self, db_session_with_containers: Session):
"""Test loading storage keys for REMOTE_URL transfer method."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
tenant_id = str(uuid4())
user_id = str(uuid4())
# Load storage keys
self.loader.load_storage_keys([file])
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == upload_file.key
def test_load_storage_keys_tool_file(self):
def test_load_storage_keys_tool_file(self, db_session_with_containers: Session):
"""Test loading storage keys for TOOL_FILE transfer method."""
# Create test data
tool_file = self._create_tool_file()
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
tenant_id = str(uuid4())
user_id = str(uuid4())
conversation_id = str(uuid4())
# Load storage keys
self.loader.load_storage_keys([file])
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
file = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file])
# Verify storage key was loaded correctly
assert file._storage_key == tool_file.file_key
def test_load_storage_keys_mixed_methods(self):
def test_load_storage_keys_mixed_methods(self, db_session_with_containers: Session):
"""Test batch loading with mixed transfer methods."""
# Create test data for different transfer methods
upload_file1 = self._create_upload_file()
upload_file2 = self._create_upload_file()
tool_file = self._create_tool_file()
tenant_id = str(uuid4())
user_id = str(uuid4())
conversation_id = str(uuid4())
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
upload_file1 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
upload_file2 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
files = [file1, file2, file3]
file1 = self._create_file(tenant_id, related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(tenant_id, related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
file3 = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
# Load storage keys
self.loader.load_storage_keys(files)
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file1, file2, file3])
# Verify all storage keys were loaded correctly
assert file1._storage_key == upload_file1.key
assert file2._storage_key == upload_file2.key
assert file3._storage_key == tool_file.file_key
def test_load_storage_keys_empty_list(self):
"""Test with empty file list."""
# Should not raise any exceptions
self.loader.load_storage_keys([])
def test_load_storage_keys_empty_list(self, db_session_with_containers: Session):
"""Test with empty file list — should not raise."""
tenant_id = str(uuid4())
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([])
def test_load_storage_keys_ignores_legacy_file_tenant_id(self):
def test_load_storage_keys_ignores_legacy_file_tenant_id(self, db_session_with_containers: Session):
"""Legacy file tenant_id should not override the loader tenant scope."""
upload_file = self._create_upload_file()
tenant_id = str(uuid4())
user_id = str(uuid4())
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
tenant_id,
related_id=upload_file.id,
transfer_method=FileTransferMethod.LOCAL_FILE,
override_tenant_id=str(uuid4()),
)
self.loader.load_storage_keys([file])
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file])
assert file._storage_key == upload_file.key
def test_load_storage_keys_missing_file_id(self):
"""Test with None file.related_id."""
# Create a file with valid parameters first, then manually set related_id to None
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
def test_load_storage_keys_missing_file_id(self, db_session_with_containers: Session):
"""Test with None file.related_id — should raise ValueError."""
tenant_id = str(uuid4())
user_id = str(uuid4())
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = None
# Should raise ValueError for None file related_id
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file])
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
with pytest.raises(ValueError, match="file id should not be None."):
loader.load_storage_keys([file])
assert str(context.value) == "file id should not be None."
def test_load_storage_keys_nonexistent_upload_file_records(self, db_session_with_containers: Session):
"""Test with missing UploadFile database records — should raise ValueError."""
tenant_id = str(uuid4())
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
def test_load_storage_keys_nonexistent_upload_file_records(self):
"""Test with missing UploadFile database records."""
# Create file with non-existent upload file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Should raise ValueError for missing record
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
loader.load_storage_keys([file])
def test_load_storage_keys_nonexistent_tool_file_records(self):
"""Test with missing ToolFile database records."""
# Create file with non-existent tool file id
non_existent_id = str(uuid4())
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
def test_load_storage_keys_nonexistent_tool_file_records(self, db_session_with_containers: Session):
"""Test with missing ToolFile database records — should raise ValueError."""
tenant_id = str(uuid4())
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.TOOL_FILE)
# Should raise ValueError for missing record
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
loader.load_storage_keys([file])
def test_load_storage_keys_invalid_uuid(self):
"""Test with invalid UUID format."""
# Create a file with valid parameters first, then manually set invalid related_id
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
def test_load_storage_keys_invalid_uuid(self, db_session_with_containers: Session):
"""Test with invalid UUID format — should raise ValueError."""
tenant_id = str(uuid4())
user_id = str(uuid4())
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file.related_id = "invalid-uuid-format"
# Should raise ValueError for invalid UUID
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
with pytest.raises(ValueError):
self.loader.load_storage_keys([file])
loader.load_storage_keys([file])
def test_load_storage_keys_batch_efficiency(self):
"""Test batched operations use efficient queries."""
# Create multiple files of different types
upload_files = [self._create_upload_file() for _ in range(3)]
tool_files = [self._create_tool_file() for _ in range(2)]
def test_load_storage_keys_batch_efficiency(self, db_session_with_containers: Session):
"""Batched operations should issue exactly 2 queries for mixed file types."""
tenant_id = str(uuid4())
user_id = str(uuid4())
conversation_id = str(uuid4())
files = []
files.extend(
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
upload_files = [self._create_upload_file(db_session_with_containers, tenant_id, user_id) for _ in range(3)]
tool_files = [
self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) for _ in range(2)
]
files = [
self._create_file(tenant_id, related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE)
for uf in upload_files
] + [
self._create_file(tenant_id, related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE)
for tf in tool_files
]
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
files.extend(
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
)
# Mock the session to count queries
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
self.loader.load_storage_keys(files)
# Should make exactly 2 queries (one for upload_files, one for tool_files)
with patch.object(
db_session_with_containers, "scalars", wraps=db_session_with_containers.scalars
) as mock_scalars:
loader.load_storage_keys(files)
# Exactly 2 DB round-trips: one for UploadFile, one for ToolFile
assert mock_scalars.call_count == 2
# Verify all storage keys were loaded correctly
for i, file in enumerate(files[:3]):
assert file._storage_key == upload_files[i].key
for i, file in enumerate(files[3:]):
assert file._storage_key == tool_files[i].file_key
def test_load_storage_keys_tenant_isolation(self):
"""Test that tenant isolation works correctly."""
# Create files for different tenants
def test_load_storage_keys_tenant_isolation(self, db_session_with_containers: Session):
"""Loader should not surface records belonging to a different tenant."""
tenant_id = str(uuid4())
other_tenant_id = str(uuid4())
user_id = str(uuid4())
# Create upload file for current tenant
upload_file_current = self._create_upload_file()
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create upload file for other tenant (but don't add to cleanup list)
upload_file_other = UploadFile(
tenant_id=other_tenant_id,
storage_type=StorageType.LOCAL,
key="other_tenant_key",
name="other_file.txt",
size=1024,
extension=".txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
created_at=datetime.now(UTC),
used=False,
upload_file_other = self._create_upload_file(
db_session_with_containers,
tenant_id,
user_id,
override_tenant_id=other_tenant_id,
)
upload_file_other.id = str(uuid4())
self.session.add(upload_file_other)
self.session.flush()
# Create file for other tenant but try to load with current tenant's loader
file_other = self._create_file(
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
tenant_id,
related_id=upload_file_other.id,
transfer_method=FileTransferMethod.LOCAL_FILE,
override_tenant_id=other_tenant_id,
)
# Should raise ValueError due to tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_other])
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
assert "Upload file not found for id:" in str(context.value)
with pytest.raises(ValueError, match="Upload file not found for id:"):
loader.load_storage_keys([file_other])
# Current tenant's file should still work
self.loader.load_storage_keys([file_current])
# Current-tenant file still resolves correctly
loader.load_storage_keys([file_current])
assert file_current._storage_key == upload_file_current.key
def test_load_storage_keys_mixed_tenant_batch(self):
"""Test batch with mixed tenant files (should fail on first mismatch)."""
# Create files for current tenant
upload_file_current = self._create_upload_file()
def test_load_storage_keys_mixed_tenant_batch(self, db_session_with_containers: Session):
"""A batch containing a foreign-tenant file should fail on the mismatch."""
tenant_id = str(uuid4())
user_id = str(uuid4())
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file_current = self._create_file(
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
)
# Create file for different tenant
other_tenant_id = str(uuid4())
file_other = self._create_file(
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
tenant_id,
related_id=str(uuid4()),
transfer_method=FileTransferMethod.LOCAL_FILE,
override_tenant_id=str(uuid4()),
)
# Should raise ValueError on tenant mismatch
with pytest.raises(ValueError) as context:
self.loader.load_storage_keys([file_current, file_other])
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
with pytest.raises(ValueError, match="Upload file not found for id:"):
loader.load_storage_keys([file_current, file_other])
assert "Upload file not found for id:" in str(context.value)
def test_load_storage_keys_duplicate_file_ids(self, db_session_with_containers: Session):
"""Duplicate file IDs in the batch should be handled gracefully."""
tenant_id = str(uuid4())
user_id = str(uuid4())
def test_load_storage_keys_duplicate_file_ids(self):
"""Test handling of duplicate file IDs in the batch."""
# Create upload file
upload_file = self._create_upload_file()
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file1 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# Create two File objects with same related_id
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
loader = StorageKeyLoader(
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
)
loader.load_storage_keys([file1, file2])
# Should handle duplicates gracefully
self.loader.load_storage_keys([file1, file2])
# Both files should have the same storage key
assert file1._storage_key == upload_file.key
assert file2._storage_key == upload_file.key
def test_load_storage_keys_session_isolation(self):
"""Test that the loader uses the provided session correctly."""
# Create test data
upload_file = self._create_upload_file()
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
def test_load_storage_keys_session_isolation(self, db_session_with_containers: Session):
"""A loader backed by an uncommitted session should not see data from another session."""
tenant_id = str(uuid4())
user_id = str(uuid4())
# Create loader with different session (same underlying connection)
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
# A loader with a fresh, separate session cannot see uncommitted rows from db_session_with_containers
with Session(bind=db.engine) as other_session:
other_loader = StorageKeyLoader(
other_session,
self.tenant_id,
tenant_id,
access_controller=DatabaseFileAccessController(),
)
with pytest.raises(ValueError):

View File

@ -8,6 +8,7 @@ Covers real Redis 7+ sharded pub/sub interactions including:
- Resource cleanup accounting via PUBSUB SHARDNUMSUB
"""
import socket
import threading
import time
import uuid
@ -356,10 +357,17 @@ class TestShardedRedisBroadcastChannelClusterIntegration:
def _get_test_topic_name(cls) -> str:
return f"test_sharded_cluster_topic_{uuid.uuid4()}"
@staticmethod
def _resolve_announced_ip(host: str) -> str:
"""Resolve the container host name to a literal IP accepted by Redis cluster config."""
return socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)[0][4][0]
@staticmethod
def _ensure_single_node_cluster(host: str, port: int) -> None:
"""Bootstrap a single-node cluster using a literal IP for Redis node advertisement."""
client = redis.Redis(host=host, port=port, decode_responses=False)
client.config_set("cluster-announce-ip", host)
announced_ip = TestShardedRedisBroadcastChannelClusterIntegration._resolve_announced_ip(host)
client.config_set("cluster-announce-ip", announced_ip)
client.config_set("cluster-announce-port", port)
slots = client.execute_command("CLUSTER", "SLOTS")
if not slots:

View File

@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -31,7 +32,7 @@ class TestApiKeyAuthService:
def mock_args(self, category, provider, mock_credentials) -> dict:
return {"category": category, "provider": provider, "credentials": mock_credentials}
def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False):
def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False):
binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id,
category=category,
@ -44,7 +45,7 @@ class TestApiKeyAuthService:
return binding
def test_get_provider_auth_list_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
db_session_with_containers.expire_all()
@ -56,14 +57,16 @@ class TestApiKeyAuthService:
assert len(tenant_results) == 1
assert tenant_results[0].provider == provider
def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id):
def test_get_provider_auth_list_empty(
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
):
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
def test_get_provider_auth_list_filters_disabled(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True
@ -78,7 +81,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_success(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self,
mock_encrypter,
mock_factory,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
mock_args,
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
@ -97,7 +106,7 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_validation_failed(
self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = False
@ -112,7 +121,13 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
def test_create_provider_auth_encrypts_api_key(
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
self,
mock_encrypter,
mock_factory,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
mock_args,
):
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
@ -128,7 +143,13 @@ class TestApiKeyAuthService:
mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key)
def test_get_auth_credentials_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials
self,
flask_app_with_containers,
db_session_with_containers: Session,
tenant_id,
category,
provider,
mock_credentials,
):
self._create_binding(
db_session_with_containers,
@ -144,14 +165,14 @@ class TestApiKeyAuthService:
assert result == mock_credentials
def test_get_auth_credentials_not_found(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
assert result is None
def test_get_auth_credentials_json_parsing(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
self._create_binding(
@ -169,7 +190,7 @@ class TestApiKeyAuthService:
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
def test_delete_provider_auth_success(
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
):
binding = self._create_binding(
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider
@ -183,7 +204,9 @@ class TestApiKeyAuthService:
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
assert remaining is None
def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id):
def test_delete_provider_auth_not_found(
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
):
# Should not raise when binding not found
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))

View File

@ -10,6 +10,7 @@ from uuid import uuid4
import httpx
import pytest
from sqlalchemy.orm import Session
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
@ -114,7 +115,7 @@ class TestAuthIntegration:
assert result2[0].tenant_id == tenant_id_2
def test_cross_tenant_access_prevention(
self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)

View File

@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.storage.storage_type import StorageType
@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument:
"user_id": user_id,
}
def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_waiting_state_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in waiting state.
@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument:
mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True")
def test_pause_document_indexing_state_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in indexing state.
@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument:
assert document.is_paused is True
assert document.paused_by == mock_document_service_dependencies["user_id"]
def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_parsing_state_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful pause of document in parsing state.
@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument:
db_session_with_containers.refresh(document)
assert document.is_paused is True
def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_completed_state_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to pause completed document.
@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument:
db_session_with_containers.refresh(document)
assert document.is_paused is False
def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_pause_document_error_state_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to pause document in error state.
@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument:
"recover_task": mock_task,
}
def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_recover_document_paused_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful recovery of paused document.
@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument:
document.dataset_id, document.id
)
def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_recover_document_not_paused_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when trying to recover non-paused document.
@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument:
"user_id": user_id,
}
def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_retry_document_single_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful retry of single document.
@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument:
dataset.id, [document.id], mock_document_service_dependencies["user_id"]
)
def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_retry_document_multiple_success(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful retry of multiple documents.
@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument:
)
def test_retry_document_concurrent_retry_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is already being retried.
@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument:
assert document.indexing_status == IndexingStatus.ERROR
def test_retry_document_missing_current_user_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when current_user is missing.
@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
}
def test_batch_update_document_status_enable_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch enabling of documents.
@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
assert mock_document_service_dependencies["add_task"].delay.call_count == 2
def test_batch_update_document_status_disable_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch disabling of documents.
@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_archive_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch archiving of documents.
@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_unarchive_success(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test successful batch unarchiving of documents.
@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id)
def test_batch_update_document_status_empty_list(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test handling of empty document list.
@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
mock_document_service_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_document_status_document_indexing_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is being indexed.
@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument:
"current_user": mock_current_user,
}
def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies):
"""
Test successful document renaming.
@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument:
assert result == document
assert document.name == new_name
def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_with_built_in_fields(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test document renaming with built-in fields enabled.
@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument:
assert document.doc_metadata["document_name"] == new_name
assert document.doc_metadata["existing_key"] == "existing_value"
def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_with_upload_file(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test document renaming with associated upload file.
@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument:
assert upload_file.name == new_name
def test_rename_document_dataset_not_found_error(
self, db_session_with_containers, mock_document_service_dependencies
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when dataset is not found.
@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument:
with pytest.raises(ValueError, match="Dataset not found"):
DocumentService.rename_document(dataset_id, document_id, new_name)
def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_not_found_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when document is not found.
@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument:
with pytest.raises(ValueError, match="Document not found"):
DocumentService.rename_document(dataset.id, document_id, new_name)
def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies):
def test_rename_document_permission_error(
self, db_session_with_containers: Session, mock_document_service_dependencies
):
"""
Test error when user lacks permission.

View File

@ -11,6 +11,7 @@ from uuid import uuid4
import pytest
from redis import RedisError
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin
@ -122,7 +123,7 @@ class TestSyncAccountDeletion:
mock_queue_task.assert_not_called()
def test_sync_account_deletion_multiple_workspaces(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_ids = [str(uuid4()) for _ in range(3)]
@ -144,7 +145,7 @@ class TestSyncAccountDeletion:
assert queued_workspace_ids == set(tenant_ids)
def test_sync_account_deletion_no_workspaces(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
@ -155,7 +156,7 @@ class TestSyncAccountDeletion:
mock_queue_task.assert_not_called()
def test_sync_account_deletion_partial_failure(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_ids = [str(uuid4()) for _ in range(3)]
@ -180,7 +181,7 @@ class TestSyncAccountDeletion:
assert mock_queue_task.call_count == 3
def test_sync_account_deletion_all_failures(
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
):
account_id = str(uuid4())
tenant_id = str(uuid4())

View File

@ -0,0 +1,94 @@
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
def _tenant_id() -> str:
return str(uuid4())
def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None:
    """Fetch the single persisted permission row for *tenant_id*, or None when absent."""
    # Expire cached instances so the lookup reflects committed database state.
    session.expire_all()
    query = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
    return session.scalars(query).one_or_none()
def _count_permissions(session: Session, tenant_id: str) -> int:
    """Count persisted permission rows for *tenant_id* (0 when the scalar is None)."""
    count_query = (
        select(func.count())
        .select_from(TenantPluginPermission)
        .where(TenantPluginPermission.tenant_id == tenant_id)
    )
    return session.scalar(count_query) or 0
class TestGetPermission:
    """Integration tests for PluginPermissionService.get_permission using testcontainers."""

    def test_returns_permission_when_found(self, db_session_with_containers: Session):
        # Seed a permission row, then verify the service reads it back intact.
        tenant_id = _tenant_id()
        seeded = TenantPluginPermission(
            tenant_id=tenant_id,
            install_permission=TenantPluginPermission.InstallPermission.ADMINS,
            debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
        )
        db_session_with_containers.add(seeded)
        db_session_with_containers.commit()

        fetched = PluginPermissionService.get_permission(tenant_id)

        assert fetched is not None
        assert fetched.id == seeded.id
        assert fetched.tenant_id == tenant_id
        assert fetched.install_permission == TenantPluginPermission.InstallPermission.ADMINS
        assert fetched.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE

    def test_returns_none_when_not_found(self, db_session_with_containers: Session):
        # A tenant with no persisted row yields None rather than raising.
        assert PluginPermissionService.get_permission(_tenant_id()) is None
class TestChangePermission:
    """Integration tests for PluginPermissionService.change_permission using testcontainers."""

    def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session):
        # Changing permissions for a tenant with no row should insert one.
        tenant_id = _tenant_id()

        changed = PluginPermissionService.change_permission(
            tenant_id,
            TenantPluginPermission.InstallPermission.EVERYONE,
            TenantPluginPermission.DebugPermission.EVERYONE,
        )

        stored = _get_permission(db_session_with_containers, tenant_id)
        assert changed is True
        assert stored is not None
        assert stored.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
        assert stored.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE

    def test_updates_existing_permission(self, db_session_with_containers: Session):
        # Changing permissions for a tenant that already has a row should update
        # that row in place rather than inserting a duplicate.
        tenant_id = _tenant_id()
        original = TenantPluginPermission(
            tenant_id=tenant_id,
            install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
            debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
        )
        db_session_with_containers.add(original)
        db_session_with_containers.commit()

        changed = PluginPermissionService.change_permission(
            tenant_id,
            TenantPluginPermission.InstallPermission.ADMINS,
            TenantPluginPermission.DebugPermission.ADMINS,
        )

        stored = _get_permission(db_session_with_containers, tenant_id)
        assert changed is True
        assert stored is not None
        assert stored.id == original.id
        assert stored.install_permission == TenantPluginPermission.InstallPermission.ADMINS
        assert stored.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
        assert _count_permissions(db_session_with_containers, tenant_id) == 1

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from models.model import App, RecommendedApp, Site
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_type import RecommendAppType
@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval:
class TestFetchRecommendedAppsFromDb:
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers):
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)
@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb:
assert "assistant" in result["categories"]
assert "writing" in result["categories"]
def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers):
def test_falls_back_to_default_language_when_empty(
self, flask_app_with_containers, db_session_with_containers: Session
):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)
@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb:
app_ids = {r["app_id"] for r in result["recommended_apps"]}
assert app1.id in app_ids
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers):
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
_create_site(db_session_with_containers, app_id=app1.id)
@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb:
app_ids = {r["app_id"] for r in result["recommended_apps"]}
assert app1.id not in app_ids
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers):
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_recommended_app(db_session_with_containers, app_id=app1.id)
@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb:
class TestFetchRecommendedAppDetailFromDb:
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers):
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session):
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4()))
assert result is None
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers):
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
_create_recommended_app(db_session_with_containers, app_id=app1.id)
@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb:
assert result is None
@patch("services.recommend_app.database.database_retrieval.AppDslService")
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers):
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
_create_site(db_session_with_containers, app_id=app1.id)

View File

@ -2,6 +2,7 @@ import copy
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService:
# for consistency with other test files
return {}
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_baichuan_model_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful prompt generation for Baichuan model.
@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService:
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_common_model_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful prompt generation for common models.
@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_prompt_case_insensitive_baichuan_detection(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan model detection is case insensitive.
@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService:
assert BAICHUAN_CONTEXT in prompt_text
def test_get_common_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with completion mode.
@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService:
assert "{{#histories#}}" in prompt_text
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_common_prompt_chat_app_chat_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for chat app with chat mode.
@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with completion mode.
@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation for completion app with chat mode.
@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService:
assert CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_common_prompt_no_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation without context.
@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_common_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported app mode.
@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_common_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test common prompt generation with unsupported model mode.
@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService:
# Assert: Verify empty dict is returned
assert result == {}
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_completion_prompt_with_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test completion prompt generation with context.
@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService:
assert result_text == CONTEXT + original_text
def test_get_completion_prompt_without_context(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test completion prompt generation without context.
@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService:
assert result_text == original_text
assert CONTEXT not in result_text
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_chat_prompt_with_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test chat prompt generation with context.
@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService:
assert original_text in result_text
assert result_text == CONTEXT + original_text
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_chat_prompt_without_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test chat prompt generation without context.
@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService:
assert CONTEXT not in result_text
def test_get_baichuan_prompt_chat_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with completion mode.
@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_chat_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for chat app with chat mode.
@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_completion_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with completion mode.
@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_completion_app_chat_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation for completion app with chat mode.
@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService:
assert BAICHUAN_CONTEXT in prompt_text
assert "{{#pre_prompt#}}" in prompt_text
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_baichuan_prompt_no_context(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation without context.
@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService:
assert "{{#query#}}" in prompt_text
def test_get_baichuan_prompt_unsupported_app_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported app mode.
@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_baichuan_prompt_unsupported_model_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test Baichuan prompt generation with unsupported model mode.
@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService:
assert result == {}
def test_get_prompt_all_app_modes_common_model(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with common model.
@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService:
assert result != {}
def test_get_prompt_all_app_modes_baichuan_model(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test prompt generation for all app modes with Baichuan model.
@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService:
assert result is not None
assert result != {}
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test prompt generation with edge cases.
@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService:
# Should either return a valid result or empty dict, but not crash
assert result is not None
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test that original templates are not modified.
@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService:
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
def test_baichuan_template_immutability(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that original Baichuan templates are not modified.
@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService:
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
def test_context_integration_consistency(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test consistency of context integration across different scenarios.
@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService:
assert prompt_text.startswith(CONTEXT)
def test_baichuan_context_integration_consistency(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test consistency of Baichuan context integration across different scenarios.

View File

@ -3,12 +3,15 @@ from __future__ import annotations
import base64
import json
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
import yaml
from faker import Faker
from flask import Flask
from sqlalchemy.orm import Session
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
@ -17,7 +20,7 @@ from core.trigger.constants import (
)
from extensions.ext_redis import redis_client
from graphon.enums import BuiltinNodeTypes
from models import Account, AppMode
from models import Account, App, AppMode
from models.model import AppModelConfig, IconType
from services import app_dsl_service
from services.account_service import AccountService, TenantService
@ -67,11 +70,27 @@ def _pending_yaml_content(version: str = "99.0.0") -> bytes:
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
def _app_stub(**overrides: Any) -> App:
defaults = {
"id": str(uuid4()),
"tenant_id": _DEFAULT_TENANT_ID,
"mode": AppMode.WORKFLOW.value,
"name": "n",
"description": "d",
"icon_type": IconType.EMOJI,
"icon": "i",
"icon_background": "#fff",
"use_icon_as_answer_icon": False,
"app_model_config": None,
}
return cast(App, SimpleNamespace(**(defaults | overrides)))
class TestAppDslService:
"""Integration tests for AppDslService using testcontainers."""
@pytest.fixture
def app(self, flask_app_with_containers):
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
@pytest.fixture
@ -112,7 +131,7 @@ class TestAppDslService:
"enterprise_service": mock_enterprise_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
fake = Faker()
with patch("services.account_service.FeatureService") as mock_account_feature_service:
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
@ -189,7 +208,7 @@ class TestAppDslService:
# ── Import: Validation ────────────────────────────────────────────
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers):
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid import_mode"):
service.import_app(
@ -198,7 +217,7 @@ class TestAppDslService:
yaml_content="version: '0.1.0'",
)
def test_import_app_missing_yaml_content(self, db_session_with_containers):
def test_import_app_missing_yaml_content(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -208,7 +227,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "yaml_content is required" in result.error
def test_import_app_missing_yaml_url(self, db_session_with_containers):
def test_import_app_missing_yaml_url(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -218,7 +237,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "yaml_url is required" in result.error
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers):
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -228,7 +247,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "content must be a mapping" in result.error
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers):
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}})
result = service.import_app(
@ -239,7 +258,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Invalid version type" in result.error
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers):
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -249,7 +268,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Missing app data" in result.error
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
def bad_safe_load(_content: str):
raise yaml.YAMLError("bad")
@ -264,7 +283,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert result.error.startswith("Invalid YAML format:")
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
monkeypatch.setattr(
AppDslService,
"_create_or_update_app",
@ -282,7 +301,7 @@ class TestAppDslService:
# ── Import: YAML URL ──────────────────────────────────────────────
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
monkeypatch.setattr(
app_dsl_service.ssrf_proxy,
"get",
@ -298,7 +317,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Error fetching YAML from URL: boom" in result.error
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch):
response = MagicMock()
response.content = b""
response.raise_for_status.return_value = None
@ -313,7 +332,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "Empty content" in result.error
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch):
response = MagicMock()
response.content = b"x" * (DSL_MAX_SIZE + 1)
response.raise_for_status.return_value = None
@ -328,7 +347,9 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "File size exceeds" in result.error
def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_user_attachments_keeps_original_url(
self, db_session_with_containers: Session, monkeypatch
):
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
yaml_bytes = _pending_yaml_content()
@ -354,7 +375,7 @@ class TestAppDslService:
assert result.imported_dsl_version == "99.0.0"
assert requested_urls == [yaml_url]
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch):
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch):
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
yaml_bytes = _pending_yaml_content()
@ -383,7 +404,7 @@ class TestAppDslService:
# ── Import: App ID checks ────────────────────────────────────────
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers):
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -395,7 +416,7 @@ class TestAppDslService:
assert result.error == "App not found"
def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
assert app.mode == "chat"
@ -412,7 +433,7 @@ class TestAppDslService:
# ── Import: Flow ──────────────────────────────────────────────────
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers):
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.import_app(
account=_account_mock(),
@ -432,7 +453,7 @@ class TestAppDslService:
assert stored is not None
def test_import_app_completed_uses_declared_dependencies(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
@ -466,7 +487,7 @@ class TestAppDslService:
@pytest.mark.parametrize("has_workflow", [True, False])
def test_import_app_legacy_versions_extract_dependencies(
self, db_session_with_containers, monkeypatch, has_workflow: bool
self, db_session_with_containers: Session, monkeypatch, has_workflow: bool
):
monkeypatch.setattr(
AppDslService,
@ -523,13 +544,13 @@ class TestAppDslService:
# ── Confirm Import ────────────────────────────────────────────────
def test_confirm_import_expired_returns_failed(self, db_session_with_containers):
def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
result = service.confirm_import(import_id=str(uuid4()), account=_account_mock())
assert result.status == ImportStatus.FAILED
assert "expired" in result.error
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch):
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
@ -562,7 +583,7 @@ class TestAppDslService:
assert result.app_id == created_app.id
assert redis_client.get(redis_key) is None
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers):
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123")
@ -572,7 +593,7 @@ class TestAppDslService:
assert result.status == ImportStatus.FAILED
assert "validation error" in result.error
def test_confirm_import_exception_returns_failed(self, db_session_with_containers):
def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session):
import_id = str(uuid4())
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json")
@ -583,13 +604,13 @@ class TestAppDslService:
# ── Check Dependencies ────────────────────────────────────────────
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers):
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
app_model = SimpleNamespace(id=str(uuid4()), tenant_id=_DEFAULT_TENANT_ID)
app_model = _app_stub()
result = service.check_dependencies(app_model=app_model)
assert result.leaked_dependencies == []
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch):
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch):
app_id = str(uuid4())
pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id)
redis_client.setex(
@ -614,10 +635,12 @@ class TestAppDslService:
)
service = AppDslService(db_session_with_containers)
result = service.check_dependencies(app_model=SimpleNamespace(id=app_id, tenant_id=_DEFAULT_TENANT_ID))
result = service.check_dependencies(app_model=_app_stub(id=app_id))
assert len(result.leaked_dependencies) == 1
def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies):
def test_check_dependencies_with_real_app(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
@ -633,12 +656,12 @@ class TestAppDslService:
# ── Create/Update App ─────────────────────────────────────────────
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers):
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="loss app mode"):
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch):
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch):
fixed_now = object()
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
@ -656,9 +679,7 @@ class TestAppDslService:
lambda _m: SimpleNamespace(kind="conv"),
)
app = SimpleNamespace(
id=str(uuid4()),
tenant_id=_DEFAULT_TENANT_ID,
app = _app_stub(
mode=AppMode.WORKFLOW.value,
name="old",
description="old-desc",
@ -667,7 +688,6 @@ class TestAppDslService:
icon_background="#111111",
updated_by=None,
updated_at=None,
app_model_config=None,
)
service = AppDslService(db_session_with_containers)
updated = service._create_or_update_app(
@ -693,7 +713,7 @@ class TestAppDslService:
assert app.icon_background == "#222222"
assert app.updated_at is fixed_now
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers):
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session):
account = _account_mock()
account.current_tenant_id = None
service = AppDslService(db_session_with_containers)
@ -705,7 +725,7 @@ class TestAppDslService:
)
def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
@ -741,42 +761,26 @@ class TestAppDslService:
stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}")
assert stored is not None
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers):
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing workflow data"):
service._create_or_update_app(
app=SimpleNamespace(
id=str(uuid4()),
tenant_id=_DEFAULT_TENANT_ID,
mode=AppMode.WORKFLOW.value,
name="n",
description="d",
icon_background="#fff",
app_model_config=None,
),
app=_app_stub(mode=AppMode.WORKFLOW.value),
data={"app": {"mode": AppMode.WORKFLOW.value}},
account=_account_mock(),
)
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers):
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing model_config"):
service._create_or_update_app(
app=SimpleNamespace(
id=str(uuid4()),
tenant_id=_DEFAULT_TENANT_ID,
mode=AppMode.CHAT.value,
name="n",
description="d",
icon_background="#fff",
app_model_config=None,
),
app=_app_stub(mode=AppMode.CHAT.value),
data={"app": {"mode": AppMode.CHAT.value}},
account=_account_mock(),
)
def test_create_or_update_app_chat_creates_model_config_and_sends_event(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.app_model_config_id = None
@ -795,19 +799,11 @@ class TestAppDslService:
db_session_with_containers.expire_all()
assert app.app_model_config_id is not None
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers):
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session):
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Invalid app mode"):
service._create_or_update_app(
app=SimpleNamespace(
id=str(uuid4()),
tenant_id=_DEFAULT_TENANT_ID,
mode=AppMode.RAG_PIPELINE.value,
name="n",
description="d",
icon_background="#fff",
app_model_config=None,
),
app=_app_stub(mode=AppMode.RAG_PIPELINE.value),
data={"app": {"mode": AppMode.RAG_PIPELINE.value}},
account=_account_mock(),
)
@ -828,29 +824,16 @@ class TestAppDslService:
lambda *_args, **_kwargs: model_calls.append(True),
)
workflow_app = SimpleNamespace(
workflow_app = _app_stub(
mode=AppMode.WORKFLOW.value,
tenant_id=_DEFAULT_TENANT_ID,
name="n",
icon="i",
icon_type="emoji",
icon_background="#fff",
description="d",
use_icon_as_answer_icon=False,
app_model_config=None,
)
AppDslService.export_dsl(workflow_app)
assert workflow_calls == [True]
chat_app = SimpleNamespace(
chat_app = _app_stub(
mode=AppMode.CHAT.value,
tenant_id=_DEFAULT_TENANT_ID,
name="n",
icon="i",
icon_type="emoji",
icon_background="#fff",
description="d",
use_icon_as_answer_icon=False,
app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}),
)
AppDslService.export_dsl(chat_app)
@ -863,16 +846,14 @@ class TestAppDslService:
lambda **_kwargs: None,
)
emoji_app = SimpleNamespace(
emoji_app = _app_stub(
mode=AppMode.WORKFLOW.value,
tenant_id=_DEFAULT_TENANT_ID,
name="Emoji App",
icon="🎨",
icon_type=IconType.EMOJI,
icon_background="#FF5733",
description="App with emoji icon",
use_icon_as_answer_icon=True,
app_model_config=None,
)
yaml_output = AppDslService.export_dsl(emoji_app)
data = yaml.safe_load(yaml_output)
@ -880,16 +861,14 @@ class TestAppDslService:
assert data["app"]["icon_type"] == "emoji"
assert data["app"]["icon_background"] == "#FF5733"
image_app = SimpleNamespace(
image_app = _app_stub(
mode=AppMode.WORKFLOW.value,
tenant_id=_DEFAULT_TENANT_ID,
name="Image App",
icon="https://example.com/icon.png",
icon_type=IconType.IMAGE,
icon_background="#FFEAD5",
description="App with image icon",
use_icon_as_answer_icon=False,
app_model_config=None,
)
yaml_output = AppDslService.export_dsl(image_app)
data = yaml.safe_load(yaml_output)
@ -897,7 +876,7 @@ class TestAppDslService:
assert data["app"]["icon_type"] == "image"
assert data["app"]["icon_background"] == "#FFEAD5"
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
model_config = AppModelConfig(
@ -935,7 +914,9 @@ class TestAppDslService:
assert "model_config" in exported_data
assert "dependencies" in exported_data
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_workflow_app_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"
db_session_with_containers.commit()
@ -968,7 +949,9 @@ class TestAppDslService:
assert "workflow" in exported_data
assert "dependencies" in exported_data
def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_export_dsl_with_workflow_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"
db_session_with_containers.commit()
@ -1008,7 +991,7 @@ class TestAppDslService:
assert "workflow" in exported_data
def test_export_dsl_with_invalid_workflow_id_raises_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
app.mode = "workflow"
@ -1106,7 +1089,7 @@ class TestAppDslService:
export_data: dict = {}
AppDslService._append_workflow_export_data(
export_data=export_data,
app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID),
app_model=_app_stub(),
include_secret=False,
workflow_id=None,
)
@ -1132,7 +1115,7 @@ class TestAppDslService:
with pytest.raises(ValueError, match="Missing draft workflow configuration"):
AppDslService._append_workflow_export_data(
export_data={},
app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID),
app_model=_app_stub(),
include_secret=False,
workflow_id=None,
)
@ -1160,7 +1143,7 @@ class TestAppDslService:
monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x)
app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}})
app_model = SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID, app_model_config=app_model_config)
app_model = _app_stub(app_model_config=app_model_config)
export_data: dict = {}
AppDslService._append_model_config_export_data(export_data, app_model)
@ -1169,7 +1152,7 @@ class TestAppDslService:
def test_append_model_config_export_data_requires_app_config(self):
with pytest.raises(ValueError, match="Missing app configuration"):
AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None))
AppDslService._append_model_config_export_data({}, _app_stub(app_model_config=None))
# ── Dependency Extraction ─────────────────────────────────────────

View File

@ -7,6 +7,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import App
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -184,7 +185,7 @@ class TestAppGenerateService:
return app, account
def _create_test_workflow(self, db_session_with_containers: Session, app):
def _create_test_workflow(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test workflow for testing.

View File

@ -7,7 +7,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import NotFound
import services.attachment_service as attachment_service_module
@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService
class TestAttachmentService:
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile:
upload_file = UploadFile(
tenant_id=tenant_id or str(uuid4()),
storage_type=StorageType.OPENDAL,
@ -60,7 +60,7 @@ class TestAttachmentService:
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
AttachmentService(session_factory=invalid_session_factory)
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session):
upload_file = self._create_upload_file(db_session_with_containers)
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
@ -70,7 +70,7 @@ class TestAttachmentService:
assert result == base64.b64encode(b"binary-content").decode()
mock_load.assert_called_once_with(upload_file.key)
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session):
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
with patch.object(attachment_service_module.storage, "load_once") as mock_load:

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from extensions.ext_redis import redis_client
@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache:
"""
@pytest.fixture(autouse=True)
def setup_redis_cleanup(self, flask_app_with_containers):
def setup_redis_cleanup(self, flask_app_with_containers: Flask):
"""Clean up Redis cache before and after each test."""
with flask_app_with_containers.app_context():
# Clean up before test
@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache:
return value
return None
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when all tenants are in cache."""
with flask_app_with_containers.app_context():
# Arrange
@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify API was not called
mock_get_plan_bulk.assert_not_called()
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when all tenants are not in cache."""
with flask_app_with_containers.app_context():
# Arrange
@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache:
assert ttl_1 > 0
assert ttl_1 <= 600 # Should be <= 600 seconds
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask):
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
with flask_app_with_containers.app_context():
# Arrange
@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache:
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == missing_plan["tenant-3"]
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask):
"""Test fallback to API when Redis mget fails."""
with flask_app_with_containers.app_context():
# Arrange
@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache:
assert cached_1 is not None
assert cached_2 is not None
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask):
"""Test fallback to API when cache contains invalid JSON."""
with flask_app_with_containers.app_context():
# Arrange
@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache:
cached_data_3 = json.loads(cached_3)
assert cached_data_3 == expected_plans["tenant-3"]
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask):
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
with flask_app_with_containers.app_context():
# Arrange
@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify API was called for tenant-2 and tenant-3
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask):
"""Test that pipeline failure doesn't affect return value."""
with flask_app_with_containers.app_context():
# Arrange
@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# Verify pipeline was attempted
mock_pipeline.assert_called_once()
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask):
"""Test with empty tenant_ids list."""
with flask_app_with_containers.app_context():
# Act
@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache:
# But we should check that mget was not called at all
# Since we can't easily verify this without more mocking, we just verify the result
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask):
"""Test that expired cache keys are treated as cache misses."""
with flask_app_with_containers.app_context():
# Arrange

View File

@ -7,6 +7,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory:
class TestConversationServicePagination:
"""Test conversation pagination operations."""
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers):
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session):
"""
Test that non-empty include_ids filters properly.
@ -204,7 +205,7 @@ class TestConversationServicePagination:
returned_ids = {conversation.id for conversation in result.data}
assert returned_ids == {conversations[0].id, conversations[1].id}
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers):
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session):
"""
Test that empty exclude_ids doesn't filter.
@ -237,7 +238,7 @@ class TestConversationServicePagination:
# Assert
assert len(result.data) == len(conversations)
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers):
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session):
"""
Test that non-empty exclude_ids filters properly.
@ -271,7 +272,7 @@ class TestConversationServicePagination:
returned_ids = {conversation.id for conversation in result.data}
assert returned_ids == {conversations[2].id}
def test_pagination_with_sorting_descending(self, db_session_with_containers):
def test_pagination_with_sorting_descending(self, db_session_with_containers: Session):
"""
Test pagination with descending sort order.
@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers):
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session):
"""
Test message pagination without specifying first_id.
@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == 3 # All 3 messages returned
assert result.has_more is False # No more messages available (3 < limit of 10)
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers):
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session):
"""
Test message pagination with first_id specified.
@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == 2 # Only 2 messages returned after first_id
assert result.has_more is False # No more messages available (2 < limit of 10)
def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers):
def test_pagination_by_first_id_raises_error_when_first_message_not_found(
self, db_session_with_containers: Session
):
"""
Test that FirstMessageNotExistsError is raised when first_id doesn't exist.
@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation:
limit=10,
)
def test_pagination_with_has_more_flag(self, db_session_with_containers):
def test_pagination_with_has_more_flag(self, db_session_with_containers: Session):
"""
Test that has_more flag is correctly set when there are more messages.
@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
def test_pagination_with_ascending_order(self, db_session_with_containers):
def test_pagination_with_ascending_order(self, db_session_with_containers: Session):
"""
Test message pagination with ascending order.
@ -512,7 +515,7 @@ class TestConversationServiceSummarization:
"""
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers):
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session):
"""
Test successful auto-generation of conversation name.
@ -552,7 +555,7 @@ class TestConversationServiceSummarization:
app_model.tenant_id, first_message.query, conversation.id, app_model.id
)
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers):
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session):
"""
Test that MessageNotExistsError is raised when conversation has no messages.
@ -571,7 +574,9 @@ class TestConversationServiceSummarization:
ConversationService.auto_generate_name(app_model, conversation)
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers):
def test_auto_generate_name_handles_llm_failure_gracefully(
self, mock_llm_generator, db_session_with_containers: Session
):
"""
Test that LLM generation failures are suppressed and don't crash.
@ -604,7 +609,7 @@ class TestConversationServiceSummarization:
assert conversation.name == original_name # Name remains unchanged
@patch("services.conversation_service.naive_utc_now")
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers):
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session):
"""
Test renaming conversation with manual name.
@ -638,7 +643,7 @@ class TestConversationServiceSummarization:
assert conversation.updated_at == mock_time
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session):
"""
Test rename delegates to auto_generate_name when auto_generate is True.
@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_create_annotation_from_message(
self, mock_current_account, mock_add_task, db_session_with_containers: Session
):
"""
Test creating annotation from existing message.
@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_create_annotation_without_message(
self, mock_current_account, mock_add_task, db_session_with_containers: Session
):
"""
Test creating standalone annotation without message.
@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
"""
Test updating an existing annotation.
@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation:
mock_add_task.delay.assert_not_called()
@patch("services.annotation_service.current_account_with_tenant")
def test_get_annotation_list(self, mock_current_account, db_session_with_containers):
def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session):
"""
Test retrieving paginated annotation list.
@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation:
assert result_total == 5
@patch("services.annotation_service.current_account_with_tenant")
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers):
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session):
"""
Test retrieving annotations with keyword filtering.
@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation:
@patch("services.annotation_service.add_annotation_to_index_task")
@patch("services.annotation_service.current_account_with_tenant")
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers):
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
"""
Test direct annotation insertion without message reference.
@ -919,7 +928,7 @@ class TestConversationServiceExport:
Tests retrieving conversation data for export purposes.
"""
def test_get_conversation_success(self, db_session_with_containers):
def test_get_conversation_success(self, db_session_with_containers: Session):
"""Test successful retrieval of conversation."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -937,7 +946,7 @@ class TestConversationServiceExport:
# Assert
assert result == conversation
def test_get_conversation_not_found(self, db_session_with_containers):
def test_get_conversation_not_found(self, db_session_with_containers: Session):
"""Test ConversationNotExistsError when conversation doesn't exist."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -949,7 +958,7 @@ class TestConversationServiceExport:
ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user)
@patch("services.annotation_service.current_account_with_tenant")
def test_export_annotation_list(self, mock_current_account, db_session_with_containers):
def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session):
"""Test exporting all annotations for an app."""
# Arrange
app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -977,7 +986,7 @@ class TestConversationServiceExport:
# Assert
assert len(result) == 10
def test_get_message_success(self, db_session_with_containers):
def test_get_message_success(self, db_session_with_containers: Session):
"""Test successful retrieval of a message."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -1001,7 +1010,7 @@ class TestConversationServiceExport:
# Assert
assert result == message
def test_get_message_not_found(self, db_session_with_containers):
def test_get_message_not_found(self, db_session_with_containers: Session):
"""Test MessageNotExistsError when message doesn't exist."""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
@ -1012,7 +1021,7 @@ class TestConversationServiceExport:
with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4()))
def test_get_conversation_for_end_user(self, db_session_with_containers):
def test_get_conversation_for_end_user(self, db_session_with_containers: Session):
"""
Test retrieving conversation created by end user via API.
@ -1038,7 +1047,7 @@ class TestConversationServiceExport:
assert result == conversation
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_conversation(self, mock_delete_task, db_session_with_containers):
def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session):
"""
Test conversation deletion with async cleanup.
@ -1071,7 +1080,7 @@ class TestConversationServiceExport:
mock_delete_task.delay.assert_called_once_with(conversation_id)
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers):
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session):
"""
Test deletion is denied when conversation belongs to a different account.
"""
@ -1102,7 +1111,7 @@ class TestConversationServiceExport:
mock_delete_task.delay.assert_not_called()
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session):
"""
Test that delete propagates exceptions and does not trigger the cleanup task.

View File

@ -5,7 +5,8 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from flask import Flask
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory:
@pytest.fixture
def real_conversation_service_session_factory(flask_app_with_containers):
def real_conversation_service_session_factory(flask_app_with_containers: Flask):
del flask_app_with_containers
real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers):
class TestConversationServiceVariables:
def test_get_conversational_variable_success(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -200,7 +201,7 @@ class TestConversationServiceVariables:
assert result.has_more is False
def test_get_conversational_variable_with_last_id(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -242,7 +243,7 @@ class TestConversationServiceVariables:
assert result.has_more is False
def test_get_conversational_variable_last_id_not_found_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -259,7 +260,7 @@ class TestConversationServiceVariables:
)
def test_get_conversational_variable_sets_has_more(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -287,7 +288,7 @@ class TestConversationServiceVariables:
assert result.has_more is True
def test_update_conversation_variable_success(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -320,7 +321,7 @@ class TestConversationServiceVariables:
assert result["updated_at"] == updated_at
def test_update_conversation_variable_not_found_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -337,7 +338,7 @@ class TestConversationServiceVariables:
)
def test_update_conversation_variable_type_mismatch_raises_error(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -360,7 +361,7 @@ class TestConversationServiceVariables:
)
def test_update_conversation_variable_integer_number_compatibility(
self, db_session_with_containers, real_conversation_service_session_factory
self, db_session_with_containers: Session, real_conversation_service_session_factory
):
del real_conversation_service_session_factory
factory = ConversationServiceVariableIntegrationFactory
@ -390,7 +391,7 @@ class TestConversationServiceVariables:
class TestConversationServicePaginationWithContainers:
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers):
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers:
invoke_from=InvokeFrom.WEB_APP,
)
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers):
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
base_time = datetime(2024, 1, 1, 8, 0, 0)
@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers:
assert newest.id != middle.id
assert [conversation.id for conversation in result.data] == [oldest.id]
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers):
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha")
@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers:
assert alpha.id != beta.id
assert [conversation.id for conversation in result.data] == [gamma.id]
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers):
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
end_user = factory.create_end_user(db_session_with_containers, app)
@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers:
assert account_conversation.id != end_user_conversation.id
assert [conversation.id for conversation in result.data] == [end_user_conversation.id]
def test_pagination_filters_to_account_console_source(self, db_session_with_containers):
def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session):
factory = ConversationServiceVariableIntegrationFactory
app, account = factory.create_app_and_account(db_session_with_containers)
end_user = factory.create_end_user(db_session_with_containers, app)

View File

@ -3,7 +3,7 @@
from uuid import uuid4
import pytest
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from graphon.variables import StringVariable
@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE
class TestConversationVariableUpdater:
def _create_conversation_variable(
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
self,
db_session_with_containers: Session,
*,
conversation_id: str,
variable: StringVariable,
app_id: str | None = None,
) -> ConversationVariable:
row = ConversationVariable(
id=variable.id,
@ -25,7 +30,7 @@ class TestConversationVariableUpdater:
db_session_with_containers.commit()
return row
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
self._create_conversation_variable(
@ -42,7 +47,7 @@ class TestConversationVariableUpdater:
assert row is not None
assert row.data == updated_variable.model_dump_json()
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session):
conversation_id = str(uuid4())
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
@ -50,7 +55,7 @@ class TestConversationVariableUpdater:
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
updater.update(conversation_id=conversation_id, variable=variable)
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session):
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
result = updater.flush()

View File

@ -3,6 +3,7 @@
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
@ -14,7 +15,7 @@ class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers):
def test_create_default_pool(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
@ -25,7 +26,7 @@ class TestCreditPoolService:
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
@ -35,17 +36,17 @@ class TestCreditPoolService:
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session):
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session):
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
@ -53,7 +54,7 @@ class TestCreditPoolService:
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
@ -64,11 +65,11 @@ class TestCreditPoolService:
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session):
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
@ -77,7 +78,7 @@ class TestCreditPoolService:
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
@ -89,7 +90,7 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5

View File

@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory:
class TestDatasetPermissionServiceGetPartialMemberList:
"""Verify partial-member list reads against persisted DatasetPermission rows."""
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session):
"""
Test retrieving partial member list with multiple members.
"""
@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
assert set(result) == set(expected_account_ids)
assert len(result) == 3
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session):
"""
Test retrieving partial member list with single member.
"""
@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
assert set(result) == set(expected_account_ids)
assert len(result) == 1
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session):
"""
Test retrieving partial member list when no members exist.
"""
@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
class TestDatasetPermissionServiceUpdatePartialMemberList:
"""Verify partial-member list updates against persisted DatasetPermission rows."""
def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session):
"""
Test adding new partial members to a dataset.
"""
@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {member_1.id, member_2.id}
def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session):
"""
Test replacing existing partial members with new ones.
"""
@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {new_member_1.id, new_member_2.id}
def test_update_partial_member_list_empty_list(self, db_session_with_containers):
def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session):
"""
Test updating with empty member list (clearing all members).
"""
@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
"""
Test error handling and rollback on database error.
"""
@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
class TestDatasetPermissionServiceClearPartialMemberList:
"""Verify partial-member clearing against persisted DatasetPermission rows."""
def test_clear_partial_member_list_success(self, db_session_with_containers):
def test_clear_partial_member_list_success(self, db_session_with_containers: Session):
"""
Test successful clearing of partial member list.
"""
@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session):
"""
Test clearing partial member list when no members exist.
"""
@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
"""
Test error handling and rollback on database error.
"""
@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session):
"""Test that users from different tenants cannot access dataset."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other_user)
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session):
"""Test that tenant owners can access any dataset regardless of permission level."""
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, owner)
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session):
"""Test ONLY_ME permission allows only the dataset creator to access."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, creator)
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session):
"""Test ONLY_ME permission denies access to non-creators."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(dataset, other)
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session):
"""Test ALL_TEAM permission allows any team member to access the dataset."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission:
DatasetService.check_dataset_permission(dataset, member)
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
def test_check_dataset_permission_partial_members_with_permission_success(
self, db_session_with_containers: Session
):
"""
Test that user with explicit permission can access partial_members dataset.
"""
@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission:
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert user.id in permissions
def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
def test_check_dataset_permission_partial_members_without_permission_error(
self, db_session_with_containers: Session
):
"""
Test error when user without permission tries to access partial_members dataset.
"""
@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission:
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session):
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)

View File

@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration:
class TestDocumentServicePauseRecoverRetry:
"""Tests for pause/recover/retry orchestration using real DB and Redis."""
def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"):
factory = DatasetServiceIntegrationDataFactory
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry:
db_session_with_containers.commit()
return doc, account
def test_pause_document_success(self, db_session_with_containers):
def test_pause_document_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry:
assert redis_client.get(cache_key) is not None
redis_client.delete(cache_key)
def test_pause_document_invalid_status_error(self, db_session_with_containers):
def test_pause_document_invalid_status_error(self, db_session_with_containers: Session):
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry:
with pytest.raises(DocumentIndexingError):
DocumentService.pause_document(doc)
def test_recover_document_success(self, db_session_with_containers):
def test_recover_document_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService
@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry:
assert redis_client.get(cache_key) is None
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)
def test_retry_document_indexing_success(self, db_session_with_containers):
def test_retry_document_indexing_success(self, db_session_with_containers: Session):
from extensions.ext_redis import redis_client
from services.dataset_service import DocumentService

View File

@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin
from services.dataset_service import DatasetService
@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
permission="only_me",
)
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session):
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
mock_user = Mock(id=None)

View File

@ -3,6 +3,8 @@
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory:
class TestDatasetServiceDeleteDataset:
"""Integration coverage for DatasetService.delete_dataset using testcontainers."""
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
"""Delete a dataset with documents and dispatch cleanup through the real signal handler."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset:
dataset.pipeline_id,
)
def test_delete_empty_dataset_success(self, db_session_with_containers):
def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
"""Delete an empty dataset without scheduling cleanup when both gating fields are absent."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
"""Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers):
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session):
"""Delete a dataset without cleanup when indexing exists but doc_form resolves to None."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset:
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()
def test_delete_dataset_not_found(self, db_session_with_containers):
def test_delete_dataset_not_found(self, db_session_with_containers: Session):
"""Return False without scheduling cleanup when the target dataset does not exist."""
# Arrange
owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)

View File

@ -6,6 +6,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle:
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask):
with flask_app_with_containers.app_context():
with pytest.raises(NotFound, match="Dataset not found"):
DatasetService.update_dataset_api_status(str(uuid4()), True)
@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration:
assert persisted.type == "dataset"
assert persisted.collection_name
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask):
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))

View File

@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from graphon.enums import WorkflowExecutionStatus
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.commit()
return run
def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None:
archive_log = WorkflowArchiveLog(
tenant_id=run.tenant_id,
app_id=run.app_id,
@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.add(archive_log)
db_session_with_containers.commit()
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session):
deleter = ArchivedWorkflowRunDeletion()
missing_run_id = str(uuid4())
@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == f"Workflow run {missing_run_id} not found"
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == f"Workflow run {run.id} is not archived"
def test_delete_batch_uses_repo(self, db_session_with_containers):
def test_delete_batch_uses_repo(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion:
).all()
assert remaining_runs == []
def test_delete_run_calls_repo(self, db_session_with_containers):
def test_delete_run_calls_repo(self, db_session_with_containers: Session):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion:
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None
def test_delete_run_dry_run(self, db_session_with_containers):
def test_delete_run_dry_run(self, db_session_with_containers: Session):
"""Dry run should return success without actually deleting."""
tenant_id = str(uuid4())
run = self._create_workflow_run(
@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expire_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
def test_delete_run_exception_returns_error(self, db_session_with_containers):
def test_delete_run_exception_returns_error(self, db_session_with_containers: Session):
"""Exception during deletion should return failure result."""
from unittest.mock import MagicMock, patch
@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is False
assert result.error == "Database error"
def test_delete_by_run_id_success(self, db_session_with_containers):
def test_delete_by_run_id_success(self, db_session_with_containers: Session):
"""Successfully delete an archived workflow run by ID."""
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion:
db_session_with_containers.expunge_all()
assert db_session_with_containers.get(WorkflowRun, run_id) is None
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session):
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
deleter = ArchivedWorkflowRunDeletion()

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory):
"""Test getting or creating end user with custom user_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser:
assert result.type == InvokeFrom.SERVICE_API
assert result.is_anonymous is False
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory):
"""Test getting or creating end user without user_id uses default session."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser:
# Verify _is_anonymous is set correctly (property always returns False)
assert result._is_anonymous is True
def test_get_existing_end_user(self, db_session_with_containers, factory):
def test_get_existing_end_user(self, db_session_with_containers: Session, factory):
"""Test retrieving an existing end user."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory):
"""Test creating new end user with SERVICE_API type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.app_id == app_id
assert result.session_id == user_id
def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory):
"""Test creating new end user with WEB_APP type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.type == InvokeFrom.WEB_APP
@patch("services.end_user_service.logger")
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory):
"""Test upgrading legacy end user with different type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert "Upgrading legacy EndUser" in log_call
@patch("services.end_user_service.logger")
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory):
"""Test retrieving existing end user with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.type == InvokeFrom.SERVICE_API
mock_logger.info.assert_not_called()
def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory):
"""Test creating anonymous user when user_id is None."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result._is_anonymous is True
assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory):
"""Test that query ordering prioritizes records with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
assert result.id == matching.id
assert result.id != non_matching.id
def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory):
"""Test that external_user_id is set to match session_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType:
InvokeFrom.DEBUGGER,
],
)
def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
def test_create_end_user_with_different_invoke_types(
self, db_session_with_containers: Session, invoke_type, factory
):
"""Test creating end users with different InvokeFrom types."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById:
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory):
app = factory.create_app_and_account(db_session_with_containers)
existing_user = factory.create_end_user(
db_session_with_containers,
@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById:
assert result is not None
assert result.id == existing_user.id
def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory):
app = factory.create_app_and_account(db_session_with_containers)
result = EndUserService.get_end_user_by_id(
@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch:
def factory(self):
return TestEndUserServiceFactory()
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3):
"""Create multiple apps under the same tenant."""
first_app = factory.create_app_and_account(db_session_with_containers)
tenant_id = first_app.tenant_id
@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch:
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
return tenant_id, all_apps
def test_create_batch_empty_app_ids(self, db_session_with_containers):
def test_create_batch_empty_app_ids(self, db_session_with_containers: Session):
result = EndUserService.create_end_user_batch(
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
)
assert result == {}
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch:
assert result[app_id].session_id == user_id
assert result[app_id].type == InvokeFrom.SERVICE_API
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch:
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
assert end_user._is_anonymous is True
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
user_id = f"user-{uuid4()}"
@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch:
assert len(result) == 2
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
app_ids = [a.id for a in apps]
user_id = f"user-{uuid4()}"
@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch:
for app_id in app_ids:
assert first_result[app_id].id == second_result[app_id].id
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
user_id = f"user-{uuid4()}"
@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch:
"invoke_type",
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
)
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory):
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
user_id = f"user-{uuid4()}"

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from services.feature_service import (
@ -81,7 +82,7 @@ class TestFeatureService:
fake = Faker()
return fake.uuid4()
def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful feature retrieval with billing and enterprise enabled.
@ -156,7 +157,7 @@ class TestFeatureService:
tenant_id
)
def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test feature retrieval for sandbox plan with specific limitations.
@ -222,7 +223,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_knowledge_rate_limit_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful knowledge rate limit retrieval with billing enabled.
@ -255,7 +258,7 @@ class TestFeatureService:
tenant_id
)
def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful system features retrieval with enterprise and marketplace enabled.
@ -332,7 +335,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_unauthenticated(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval for an unauthenticated user.
@ -386,7 +391,9 @@ class TestFeatureService:
# Marketplace should be visible
assert result.enable_marketplace is True
def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_basic_config(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with basic configuration (no enterprise).
@ -436,7 +443,9 @@ class TestFeatureService:
# Verify plugin package size (uses default value from dify_config)
assert result.max_plugin_package_size == 15728640
def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_billing_disabled(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when billing is disabled.
@ -492,7 +501,7 @@ class TestFeatureService:
assert result.webapp_copyright_enabled is False
def test_get_knowledge_rate_limit_billing_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test knowledge rate limit retrieval when billing is disabled.
@ -523,7 +532,9 @@ class TestFeatureService:
# Verify no billing service calls
mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called()
def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_enterprise_only(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with enterprise enabled but billing disabled.
@ -583,7 +594,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
def test_get_system_features_enterprise_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval when enterprise is disabled.
@ -640,7 +651,7 @@ class TestFeatureService:
# Verify no enterprise service calls
mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called()
def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test feature retrieval without tenant ID (billing disabled).
@ -686,7 +697,9 @@ class TestFeatureService:
# Verify no billing service calls
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_partial_billing_info(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with partial billing information.
@ -746,7 +759,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_vector_space(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case vector space configuration.
@ -807,7 +822,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_webapp_auth(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case webapp auth configuration.
@ -863,7 +878,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_members_quota(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case members quota configuration.
@ -924,7 +941,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_plugin_installation_permission_scopes(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with different plugin installation permission scopes.
@ -1023,7 +1040,7 @@ class TestFeatureService:
assert result.plugin_installation_permission.restrict_to_marketplace_only is True
def test_get_features_workspace_members_missing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when workspace members info is missing from enterprise.
@ -1064,7 +1081,9 @@ class TestFeatureService:
tenant_id
)
def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_license_inactive(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with inactive license.
@ -1117,7 +1136,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_system_features_partial_enterprise_info(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with partial enterprise information.
@ -1186,7 +1205,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_limits(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case limit values.
@ -1244,7 +1265,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_protocols(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case protocol values.
@ -1297,7 +1318,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_features_edge_case_education(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case education configuration.
@ -1353,7 +1376,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_license_limitation_model_is_available(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test LicenseLimitationModel.is_available method with various scenarios.
@ -1394,7 +1417,7 @@ class TestFeatureService:
assert exact_limit.is_available(3) is True
def test_get_features_workspace_members_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval when workspace members are disabled in enterprise.
@ -1433,7 +1456,9 @@ class TestFeatureService:
# Verify mock interactions
mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id)
def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_system_features_license_expired(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with expired license.
@ -1486,7 +1511,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_docs_processing(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case document processing configuration.
@ -1544,7 +1569,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_branding(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features retrieval with edge case branding configuration.
@ -1606,7 +1631,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_annotation_quota(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case annotation quota configuration.
@ -1668,7 +1693,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_features_edge_case_documents_upload(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with edge case documents upload settings.
@ -1733,7 +1758,7 @@ class TestFeatureService:
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
def test_get_system_features_edge_case_license_lost(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test system features with lost license status.
@ -1784,7 +1809,7 @@ class TestFeatureService:
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
def test_get_features_edge_case_education_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test feature retrieval with education feature disabled.

View File

@ -6,6 +6,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from configs import dify_config
from core.workflow.human_input_adapter import (
@ -88,7 +89,7 @@ class TestDeliveryTestRegistry:
with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."):
registry.dispatch(context=context, method=method)
def test_default(self, flask_app_with_containers, db_session_with_containers):
def test_default(self, flask_app_with_containers, db_session_with_containers: Session):
registry = DeliveryTestRegistry.default()
assert len(registry._handlers) == 1
assert isinstance(registry._handlers[0], EmailDeliveryTestHandler)
@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler:
)
assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"]
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers):
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
account = Account(name="Test User", email="member@example.com")
db_session_with_containers.add(account)
@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler:
)
assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"]
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers):
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session):
tenant_id = str(uuid4())
account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com")
account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com")

View File

@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration:
return app
def _create_conversation(self, db_session_with_containers: Session, app):
def _create_conversation(self, db_session_with_containers: Session, app: App):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.dataset import Dataset, DatasetMetadataBinding, Document
from models.enums import DataSourceType, DocumentCreatedFrom
@ -65,7 +66,7 @@ class TestMetadataPartialUpdate:
yield account
def test_partial_update_merges_metadata(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -92,7 +93,7 @@ class TestMetadataPartialUpdate:
assert updated_doc.doc_metadata["new_key"] == "new_value"
def test_full_update_replaces_metadata(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -119,7 +120,7 @@ class TestMetadataPartialUpdate:
assert "existing_key" not in updated_doc.doc_metadata
def test_partial_update_skips_existing_binding(
self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(
@ -159,7 +160,7 @@ class TestMetadataPartialUpdate:
assert len(bindings) == 1
def test_rollback_called_on_commit_failure(
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
):
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
document = _create_document(

View File

@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from models.model import OAuthProviderApp
@ -25,7 +26,7 @@ from services.oauth_server import (
class TestOAuthServerServiceGetProviderApp:
"""DB-backed tests for get_oauth_provider_app."""
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp:
app = OAuthProviderApp(
app_icon="icon.png",
client_id=client_id,
@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp:
db_session_with_containers.commit()
return app
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session):
client_id = f"client-{uuid4()}"
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp:
assert result.client_id == client_id
assert result.id == created.id
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session):
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
assert result is None

View File

@ -8,6 +8,7 @@ from datetime import datetime
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.workflow import WorkflowPause, WorkflowRun
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
@ -39,7 +40,7 @@ class TestWorkflowRunRestore:
assert result["created_at"].month == 1
assert result["name"] == "test"
def test_restore_table_records_returns_rowcount(self, db_session_with_containers):
def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session):
"""Restore should return inserted rowcount."""
restore = WorkflowRunRestore()
record_id = str(uuid4())
@ -65,7 +66,7 @@ class TestWorkflowRunRestore:
restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id))
assert restored_pause is not None
def test_restore_table_records_unknown_table(self, db_session_with_containers):
def test_restore_table_records_unknown_table(self, db_session_with_containers: Session):
"""Unknown table names should be ignored gracefully."""
restore = WorkflowRunRestore()

Some files were not shown because too many files have changed in this diff Show More