mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
Merge remote-tracking branch 'origin/main'
# Conflicts: # api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py
This commit is contained in:
commit
6c5f6699d2
@ -1,6 +1,6 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
|
||||
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
- Create or keep feature hooks only for real orchestration or shared domain behavior.
|
||||
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- legacy `web/service/use-*.ts` files when migrating wrappers away
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
short_description: "Dify TanStack Query, oRPC, and default option patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Thin hook decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
3. Create or keep feature hooks only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
|
||||
4. Treat `web/service/use-{domain}.ts` as legacy.
|
||||
- Do not create new thin service wrappers for oRPC contracts.
|
||||
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
@ -74,11 +79,37 @@ const invoiceQuery = useQuery({
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
```typescript
|
||||
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
|
||||
```
|
||||
|
||||
## Thin Hook Decision Rule
|
||||
|
||||
Remove thin hooks when they only rename a single oRPC query or mutation helper.
|
||||
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
|
||||
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
|
||||
|
||||
Use:
|
||||
|
||||
```typescript
|
||||
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
|
||||
```
|
||||
|
||||
Keep:
|
||||
|
||||
```typescript
|
||||
const applyTagBindingsMutation = useApplyTagBindingsMutation()
|
||||
```
|
||||
|
||||
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
|
||||
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- oRPC default options
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) {
|
||||
}
|
||||
```
|
||||
|
||||
## oRPC Default Options
|
||||
|
||||
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
|
||||
|
||||
Place defaults at the query utility creation point in `web/service/client.ts`:
|
||||
|
||||
```typescript
|
||||
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
|
||||
path: ['console'],
|
||||
experimental_defaults: {
|
||||
tags: {
|
||||
create: {
|
||||
mutationOptions: {
|
||||
onSuccess: (tag, _variables, _result, context) => {
|
||||
context.client.setQueryData(
|
||||
consoleQuery.tags.list.queryKey({
|
||||
input: {
|
||||
query: {
|
||||
type: tag.type,
|
||||
},
|
||||
},
|
||||
}),
|
||||
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
|
||||
)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
|
||||
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
|
||||
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
|
||||
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
|
||||
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
|
||||
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
@ -49,7 +91,7 @@ Use:
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
// Feature orchestration owns cache invalidation only when defaults are not enough.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
|
||||
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
|
||||
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@ -6,6 +6,9 @@
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# ESLint suppression file is maintained by autofix.ci pruning.
|
||||
/eslint-suppressions.json
|
||||
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
|
||||
2
.github/actions/setup-web/action.yml
vendored
2
.github/actions/setup-web/action.yml
vendored
@ -4,7 +4,7 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
|
||||
1
.github/labeler.yml
vendored
1
.github/labeler.yml
vendored
@ -6,5 +6,4 @@ web:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
3
.github/workflows/autofix.yml
vendored
3
.github/workflows/autofix.yml
vendored
@ -43,7 +43,6 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
@ -114,7 +113,7 @@ jobs:
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Setup web environment
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: ESLint autofix
|
||||
|
||||
8
.github/workflows/build-push.yml
vendored
8
.github/workflows/build-push.yml
vendored
@ -74,7 +74,7 @@ jobs:
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@v1
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
@ -84,7 +84,7 @@ jobs:
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: depot/build-push-action@v1
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
context: ${{ matrix.build_context }}
|
||||
@ -124,10 +124,10 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Validate Docker image
|
||||
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.build_context }}
|
||||
|
||||
8
.github/workflows/docker-build.yml
vendored
8
.github/workflows/docker-build.yml
vendored
@ -44,10 +44,10 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@v1
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: depot/build-push-action@v1
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
@ -71,10 +71,10 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
|
||||
2
.github/workflows/main-ci.yml
vendored
2
.github/workflows/main-ci.yml
vendored
@ -69,7 +69,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@ -83,7 +82,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
|
||||
1
.github/workflows/style.yml
vendored
1
.github/workflows/style.yml
vendored
@ -83,7 +83,6 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,7 +9,6 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
|
||||
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -219,6 +219,9 @@ node_modules
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
# generated API OpenAPI specs
|
||||
packages/contracts/openapi/
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
|
||||
@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
./dify-compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
#### Seeking help
|
||||
@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
|
||||
@ -113,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
# Validates name encoding for non-Latin characters.
|
||||
name = name.strip().encode("utf-8").decode("utf-8") if name else None
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
# Generate a random password that satisfies the password policy.
|
||||
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
|
||||
for _ in range(100):
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
valid_password(new_password)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
|
||||
return
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
|
||||
@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response):
|
||||
# Try to extract filename from URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
# Decode percent-encoded characters in the path segment
|
||||
filename = urllib.parse.unquote(os.path.basename(url_path))
|
||||
|
||||
# If filename couldn't be extracted, use Content-Disposition header
|
||||
if not filename:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
@ -8,6 +9,7 @@ from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
@ -66,22 +69,19 @@ class AppListQuery(BaseModel):
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("Unsupported tag_ids type.")
|
||||
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
if not items:
|
||||
return None
|
||||
|
||||
@ -91,6 +91,26 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -455,7 +475,7 @@ class AppListApi(Resource):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
|
||||
@ -60,7 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -158,8 +159,13 @@ class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
app_ids: str = Field(..., description="Comma-separated app IDs")
|
||||
class WorkflowOnlineUsersPayload(BaseModel):
|
||||
app_ids: list[str] = Field(default_factory=list, description="App IDs")
|
||||
|
||||
@field_validator("app_ids")
|
||||
@classmethod
|
||||
def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
|
||||
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
@ -186,7 +192,7 @@ reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -1384,19 +1390,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
def post(self):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
|
||||
app_ids = args.app_ids
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
|
||||
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
@ -1404,13 +1410,24 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
|
||||
|
||||
users_json_by_app_id: dict[str, Any] = {}
|
||||
for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
|
||||
app_id_batch = ordered_accessible_app_ids[
|
||||
start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
|
||||
]
|
||||
pipe = redis_client.pipeline(transaction=False)
|
||||
for app_id in app_id_batch:
|
||||
pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
|
||||
users_json_batch = pipe.execute()
|
||||
for app_id, users_json in zip(app_id_batch, users_json_batch):
|
||||
users_json_by_app_id[app_id] = users_json
|
||||
|
||||
results = []
|
||||
for app_id in app_ids:
|
||||
if app_id not in accessible_app_ids:
|
||||
continue
|
||||
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
for app_id in ordered_accessible_app_ids:
|
||||
users_json = users_json_by_app_id.get(app_id, {})
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
|
||||
@ -75,14 +75,15 @@ console_ns.schema_model(
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
match value:
|
||||
case FileSegment():
|
||||
return value.value.model_dump()
|
||||
case ArrayFileSegment():
|
||||
return [i.model_dump() for i in value.value]
|
||||
case SegmentGroup():
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
case _:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
@ -32,12 +32,7 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagBindingItemDeletePayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
@ -75,7 +70,6 @@ register_schema_models(
|
||||
TagBasePayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagBindingItemDeletePayload,
|
||||
TagListQueryParam,
|
||||
TagResponse,
|
||||
)
|
||||
@ -184,13 +178,13 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_binding() -> tuple[dict[str, str], int]:
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(
|
||||
tag_id=payload.tag_id,
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
@ -211,54 +205,15 @@ class TagBindingCollectionApi(Resource):
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/<uuid:id>")
|
||||
class TagBindingItemApi(Resource):
|
||||
"""Canonical item resource for tag binding deletion."""
|
||||
|
||||
@console_ns.doc("delete_tag_binding")
|
||||
@console_ns.doc(params={"id": "Tag ID"})
|
||||
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
_require_tag_binding_edit_permission()
|
||||
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(
|
||||
tag_id=str(id),
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class DeprecatedTagBindingCreateApi(Resource):
|
||||
"""Deprecated verb-based alias for tag binding creation."""
|
||||
|
||||
@console_ns.doc("create_tag_binding_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class DeprecatedTagBindingRemoveApi(Resource):
|
||||
"""Deprecated verb-based alias for tag binding deletion."""
|
||||
class TagBindingRemoveApi(Resource):
|
||||
"""Batch resource for tag binding deletion."""
|
||||
|
||||
@console_ns.doc("delete_tag_binding_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
|
||||
@console_ns.doc("remove_tag_bindings")
|
||||
@console_ns.doc(description="Remove one or more tag bindings from a target.")
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _remove_tag_binding()
|
||||
return _remove_tag_bindings()
|
||||
|
||||
@ -8,6 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -45,6 +46,8 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@ -322,9 +325,24 @@ class AccountAvatarApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
avatar = args.avatar
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.tenant_id != current_tenant_id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str
|
||||
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
|
||||
|
||||
tag_ids: list[str] = Field(default_factory=list)
|
||||
tag_id: str | None = None
|
||||
target_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_legacy_tag_id(cls, data: object) -> object:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("tag_ids") and data.get("tag_id"):
|
||||
return {**data, "tag_ids": [data["tag_id"]]}
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tag_ids(self) -> "TagUnbindingPayload":
|
||||
if not self.tag_ids:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return self
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc("unbind_dataset_tags")
|
||||
@service_api_ns.doc(description="Unbind tags from a dataset")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Tag unbound successfully",
|
||||
204: "Tags unbound successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -468,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from an uploaded file for canonical and deprecated routes."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args: dict[str, object] = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Deprecated resource aliases for file document updates."""
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc("update_document_by_file_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by uploading a file. "
|
||||
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
@ -487,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file through the deprecated file-update aliases."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
@ -876,6 +886,22 @@ class DocumentApi(DatasetApiResource):
|
||||
|
||||
return response
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file on the canonical document resource."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
@service_api_ns.doc("delete_document")
|
||||
@service_api_ns.doc(description="Delete a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
||||
@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert audio to text"""
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
@ -151,6 +151,12 @@ def deserialize_response(raw_data: bytes) -> Response:
|
||||
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
# Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
|
||||
# parsed ones so we faithfully reproduce the original response. Use Headers.add
|
||||
# rather than dict-style assignment so that repeated headers such as Set-Cookie
|
||||
# (and any other multi-valued header per RFC 9110) are preserved instead of
|
||||
# being overwritten.
|
||||
response.headers.clear()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
@ -158,6 +164,6 @@ def deserialize_response(raw_data: bytes) -> Response:
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
response.headers.add(name, value.strip())
|
||||
|
||||
return response
|
||||
|
||||
@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@ -445,7 +445,7 @@ class ProviderManager:
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
@ -462,7 +462,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
@ -478,7 +478,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
@ -496,7 +496,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
@ -514,7 +514,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
@ -544,7 +544,7 @@ class ProviderManager:
|
||||
return {}
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
@ -578,7 +578,7 @@ class ProviderManager:
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
@ -608,7 +608,7 @@ class ProviderManager:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
|
||||
@ -217,10 +217,11 @@ class RetrievalService:
|
||||
"""Deduplicate documents in O(n) while preserving first-seen order.
|
||||
|
||||
Rules:
|
||||
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
|
||||
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
|
||||
- For non-dify documents (or dify without doc_id): deduplicate by content key
|
||||
(provider, page_content), keeping the first occurrence.
|
||||
- If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key;
|
||||
keep the doc with the highest metadata["score"] among duplicates. If a later duplicate
|
||||
has no score, ignore it.
|
||||
- If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content),
|
||||
keeping the first occurrence.
|
||||
"""
|
||||
if not documents:
|
||||
return documents
|
||||
@ -231,11 +232,10 @@ class RetrievalService:
|
||||
order: list[tuple] = []
|
||||
|
||||
for doc in documents:
|
||||
is_dify = doc.provider == "dify"
|
||||
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
|
||||
doc_id = (doc.metadata or {}).get("doc_id")
|
||||
|
||||
if is_dify and doc_id:
|
||||
key = ("dify", doc_id)
|
||||
if doc_id:
|
||||
key = (doc.provider or "dify", doc_id)
|
||||
if key not in chosen:
|
||||
chosen[key] = doc
|
||||
order.append(key)
|
||||
|
||||
@ -144,8 +144,20 @@ class Vector:
|
||||
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
return get_vector_factory_class(vector_type)
|
||||
|
||||
@staticmethod
|
||||
def _filter_empty_text_documents(documents: list[Document]) -> list[Document]:
|
||||
filtered_documents = [document for document in documents if document.page_content.strip()]
|
||||
skipped_count = len(documents) - len(filtered_documents)
|
||||
if skipped_count:
|
||||
logger.warning("skip %d empty documents before vector embedding", skipped_count)
|
||||
return filtered_documents
|
||||
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
if texts:
|
||||
texts = self._filter_empty_text_documents(texts)
|
||||
if not texts:
|
||||
return
|
||||
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
batch_size = 1000
|
||||
@ -203,8 +215,14 @@ class Vector:
|
||||
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
documents = self._filter_empty_text_documents(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
|
||||
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
|
||||
|
||||
@ -1078,6 +1078,13 @@ class ToolManager:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
if variable_pool:
|
||||
config = tool_configurations.get(parameter.name, {})
|
||||
|
||||
selector_value = cls._extract_runtime_selector_value(parameter, config)
|
||||
if selector_value is not None:
|
||||
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
|
||||
runtime_parameters[parameter.name] = selector_value
|
||||
continue
|
||||
|
||||
if not (config and isinstance(config, dict) and config.get("value") is not None):
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
@ -1105,5 +1112,39 @@ class ToolManager:
|
||||
runtime_parameters[parameter.name] = value
|
||||
return runtime_parameters
|
||||
|
||||
@classmethod
|
||||
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
|
||||
if parameter.type not in {
|
||||
ToolParameter.ToolParameterType.MODEL_SELECTOR,
|
||||
ToolParameter.ToolParameterType.APP_SELECTOR,
|
||||
}:
|
||||
return None
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
|
||||
input_value = config.get("value")
|
||||
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
|
||||
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
|
||||
|
||||
if cls._is_selector_value(parameter, config):
|
||||
selector_value = dict(config)
|
||||
selector_value.pop("type", None)
|
||||
selector_value.pop("value", None)
|
||||
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
|
||||
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
|
||||
return (
|
||||
isinstance(value.get("provider"), str)
|
||||
and isinstance(value.get("model"), str)
|
||||
and isinstance(value.get("model_type"), str)
|
||||
)
|
||||
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
|
||||
return isinstance(value.get("app_id"), str)
|
||||
return False
|
||||
|
||||
|
||||
ToolManager.load_hardcoded_providers_cache()
|
||||
|
||||
@ -272,6 +272,14 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
selector_value = _extract_selector_configuration(value)
|
||||
if selector_value is not None:
|
||||
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
|
||||
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
@ -310,6 +318,28 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
|
||||
return None
|
||||
|
||||
|
||||
def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
|
||||
input_value = value.get("value")
|
||||
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
|
||||
return dict(input_value)
|
||||
|
||||
if _is_selector_configuration(value):
|
||||
selector_value = dict(value)
|
||||
selector_value.pop("type", None)
|
||||
selector_value.pop("value", None)
|
||||
return selector_value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
|
||||
return (
|
||||
isinstance(value.get("provider"), str)
|
||||
and isinstance(value.get("model"), str)
|
||||
and isinstance(value.get("model_type"), str)
|
||||
) or isinstance(value.get("app_id"), str)
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
|
||||
@ -365,7 +365,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
adapted_node_config = adapt_node_config_for_graph(node_config)
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapted_node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
@ -373,6 +374,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
config_for_node_init: BaseNodeData | dict[str, Any]
|
||||
if isinstance(resolved_node_data, BaseNodeData):
|
||||
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
else:
|
||||
config_for_node_init = resolved_node_data
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -442,7 +448,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
config=config_for_node_init,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@ -474,10 +480,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
include_retriever_attachment_loader: bool,
|
||||
include_jinja2_template_renderer: bool,
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(
|
||||
LLMCompatibleNodeData,
|
||||
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
|
||||
)
|
||||
validated_node_data = cast(LLMCompatibleNodeData, node_data)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
|
||||
@ -501,11 +501,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
|
||||
tool_configurations = dict(node_data.tool_configurations)
|
||||
tool_configurations.update(
|
||||
{name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()}
|
||||
)
|
||||
return _WorkflowToolRuntimeSpec(
|
||||
provider_type=CoreToolProviderType(node_data.provider_type.value),
|
||||
provider_id=node_data.provider_id,
|
||||
tool_name=node_data.tool_name,
|
||||
tool_configurations=dict(node_data.tool_configurations),
|
||||
tool_configurations=tool_configurations,
|
||||
credential_id=node_data.credential_id,
|
||||
)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from events.app_event import app_draft_workflow_was_synced
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.nodes.tool.entities import ToolEntity
|
||||
@ -19,7 +20,8 @@ def handle(sender, **kwargs):
|
||||
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
|
||||
if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL:
|
||||
try:
|
||||
tool_entity = ToolEntity.model_validate(node_data["data"])
|
||||
adapted_node_data = adapt_node_config_for_graph(node_data)
|
||||
tool_entity = ToolEntity.model_validate(adapted_node_data["data"])
|
||||
provider_type = ToolProviderType(tool_entity.provider_type.value)
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
provider_type=provider_type,
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from flask import Flask
|
||||
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
def init_app(app):
|
||||
def init_app(app: Flask):
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
|
||||
@ -298,7 +298,7 @@ def _build_from_datasource_file(
|
||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||
|
||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||
detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
|
||||
file_type = _resolve_file_type(
|
||||
detected_file_type=detected_file_type,
|
||||
specified_type=mapping.get("type"),
|
||||
|
||||
@ -19,8 +19,13 @@ from werkzeug.http import parse_options_header
|
||||
from core.helper import ssrf_proxy
|
||||
|
||||
|
||||
def extract_filename(url_path: str, content_disposition: str | None) -> str | None:
|
||||
"""Extract a safe filename from Content-Disposition or the request URL path."""
|
||||
def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
|
||||
"""Extract a safe filename from Content-Disposition or the request URL path.
|
||||
|
||||
Handles full URLs, paths with query strings, hash fragments, and percent-encoded segments.
|
||||
Query strings and hash fragments are stripped from the URL before extracting the basename.
|
||||
Percent-encoded characters in the path are decoded safely.
|
||||
"""
|
||||
filename: str | None = None
|
||||
if content_disposition:
|
||||
filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
|
||||
@ -47,8 +52,13 @@ def extract_filename(url_path: str, content_disposition: str | None) -> str | No
|
||||
filename = urllib.parse.unquote(raw)
|
||||
|
||||
if not filename:
|
||||
candidate = os.path.basename(url_path)
|
||||
filename = urllib.parse.unquote(candidate) if candidate else None
|
||||
# Parse the URL to extract just the path, stripping query strings and fragments
|
||||
# This handles both full URLs and bare paths
|
||||
parsed = urllib.parse.urlparse(url_or_path)
|
||||
path = parsed.path
|
||||
candidate = os.path.basename(path)
|
||||
# Decode percent-encoded characters, with safe fallback for malformed input
|
||||
filename = urllib.parse.unquote(candidate, errors="replace") if candidate else None
|
||||
|
||||
if filename:
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
from typing import TypeGuard
|
||||
|
||||
|
||||
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
|
||||
return isinstance(v, dict)
|
||||
|
||||
|
||||
def is_str(v: object) -> TypeGuard[str]:
|
||||
return isinstance(v, str)
|
||||
@ -9,11 +9,11 @@ import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
@ -82,7 +82,8 @@ class Provider(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
)
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
|
||||
@ -1568,12 +1568,14 @@ class WorkflowDraftVariable(Base):
|
||||
),
|
||||
)
|
||||
|
||||
# Relationship to WorkflowDraftVariableFile
|
||||
# WorkflowDraftVariableFile uses TypeBase while WorkflowDraftVariable uses Base, so the relationship
|
||||
# must resolve the class object lazily instead of relying on string lookup across registries.
|
||||
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
|
||||
lambda: WorkflowDraftVariableFile,
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id",
|
||||
primaryjoin=lambda: orm.foreign(WorkflowDraftVariable.file_id) == WorkflowDraftVariableFile.id,
|
||||
)
|
||||
|
||||
# Cache for deserialized value
|
||||
@ -1892,7 +1894,7 @@ class WorkflowDraftVariable(Base):
|
||||
return self.last_edited_at is not None
|
||||
|
||||
|
||||
class WorkflowDraftVariableFile(Base):
|
||||
class WorkflowDraftVariableFile(TypeBase):
|
||||
"""Stores metadata about files associated with large workflow draft variables.
|
||||
|
||||
This model acts as an intermediary between WorkflowDraftVariable and UploadFile,
|
||||
@ -1906,18 +1908,7 @@ class WorkflowDraftVariableFile(Base):
|
||||
__tablename__ = "workflow_draft_variable_files"
|
||||
|
||||
# Primary key
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
primary_key=True,
|
||||
default=lambda: str(uuidv7()),
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=naive_utc_now,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default_factory=lambda: str(uuidv7()), init=False)
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(
|
||||
StringUUID,
|
||||
@ -1969,15 +1960,23 @@ class WorkflowDraftVariableFile(Base):
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationship to UploadFile
|
||||
# Rows are created with `upload_file_id`; callers should load this relationship explicitly when needed.
|
||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
||||
UploadFile,
|
||||
foreign_keys=[upload_file_id],
|
||||
lazy="raise",
|
||||
init=False,
|
||||
uselist=False,
|
||||
primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id,
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default_factory=naive_utc_now,
|
||||
server_default=func.current_timestamp(),
|
||||
)
|
||||
|
||||
|
||||
def is_system_variable_editable(name: str) -> bool:
|
||||
return name in _EDITABLE_SYSTEM_VARIABLE
|
||||
|
||||
@ -246,8 +246,18 @@ class TidbService:
|
||||
userPrefix = item["userPrefix"]
|
||||
if state == "ACTIVE" and len(userPrefix) > 0:
|
||||
cluster_info = tidb_serverless_list_map[item["clusterId"]]
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
cluster_info.account = f"{userPrefix}.root"
|
||||
if not cluster_info.qdrant_endpoint:
|
||||
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
|
||||
item
|
||||
) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
|
||||
if cluster_info.qdrant_endpoint:
|
||||
cluster_info.status = TidbAuthBindingStatus.ACTIVE
|
||||
else:
|
||||
logger.warning(
|
||||
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
|
||||
item["clusterId"],
|
||||
)
|
||||
db.session.add(cluster_info)
|
||||
db.session.commit()
|
||||
else:
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
|
||||
|
||||
from models.enums import TidbAuthBindingStatus
|
||||
|
||||
|
||||
class TestExtractQdrantEndpoint:
|
||||
"""Unit tests for TidbService.extract_qdrant_endpoint."""
|
||||
@ -216,3 +219,86 @@ class TestBatchCreateEdgeCases:
|
||||
private_key="priv",
|
||||
region="us-east-1",
|
||||
)
|
||||
|
||||
|
||||
class TestBatchUpdateTidbServerlessClusterStatus:
|
||||
"""Verify that status updates only expose clusters after qdrant endpoint is ready."""
|
||||
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {
|
||||
"clusters": [
|
||||
{
|
||||
"clusterId": "c-1",
|
||||
"state": "ACTIVE",
|
||||
"userPrefix": "pfx",
|
||||
"endpoints": {"public": {"host": "gw.tidbcloud.com"}},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
|
||||
assert binding.status == TidbAuthBindingStatus.ACTIVE
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
|
||||
assert binding.status == TidbAuthBindingStatus.ACTIVE
|
||||
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
|
||||
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
|
||||
def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
|
||||
binding = SimpleNamespace(
|
||||
cluster_id="c-1",
|
||||
status=TidbAuthBindingStatus.CREATING,
|
||||
account="root",
|
||||
qdrant_endpoint=None,
|
||||
)
|
||||
mock_http.get.return_value = MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
|
||||
)
|
||||
|
||||
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
|
||||
|
||||
assert binding.account == "pfx.root"
|
||||
assert binding.qdrant_endpoint is None
|
||||
assert binding.status == TidbAuthBindingStatus.CREATING
|
||||
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.3"
|
||||
version = "1.14.0"
|
||||
requires-python = "~=3.12.0"
|
||||
|
||||
dependencies = [
|
||||
# Legacy: mature and widely deployed
|
||||
"bleach>=6.3.0",
|
||||
"boto3>=1.42.96",
|
||||
"boto3>=1.43.3",
|
||||
"celery>=5.6.3",
|
||||
"croniter>=6.2.2",
|
||||
"flask>=3.1.3,<4.0.0",
|
||||
@ -14,7 +14,7 @@ dependencies = [
|
||||
"gevent>=26.4.0",
|
||||
"gevent-websocket>=0.10.1",
|
||||
"gmpy2>=2.3.0",
|
||||
"google-api-python-client>=2.194.0",
|
||||
"google-api-python-client>=2.195.0",
|
||||
"gunicorn>=25.3.0",
|
||||
"psycogreen>=1.0.2",
|
||||
"psycopg2-binary>=2.9.12",
|
||||
@ -31,7 +31,7 @@ dependencies = [
|
||||
"flask-migrate>=4.1.0,<5.0.0",
|
||||
"flask-orjson>=2.0.0,<3.0.0",
|
||||
"flask-restx>=1.3.2,<2.0.0",
|
||||
"google-cloud-aiplatform>=1.148.1,<2.0.0",
|
||||
"google-cloud-aiplatform>=1.149.0,<2.0.0",
|
||||
"httpx[socks]>=0.28.1,<1.0.0",
|
||||
"opentelemetry-distro>=0.62b1,<1.0.0",
|
||||
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
|
||||
@ -127,7 +127,7 @@ dev = [
|
||||
"testcontainers>=4.14.2",
|
||||
"types-aiofiles>=25.1.0",
|
||||
"types-beautifulsoup4>=4.12.0",
|
||||
"types-cachetools>=6.2.0",
|
||||
"types-cachetools>=7.0.0.20260503",
|
||||
"types-colorama>=0.4.15",
|
||||
"types-defusedxml>=0.7.0",
|
||||
"types-deprecated>=1.3.1",
|
||||
@ -135,7 +135,7 @@ dev = [
|
||||
"types-flask-cors>=6.0.0",
|
||||
"types-flask-migrate>=4.1.0",
|
||||
"types-gevent>=26.4.0",
|
||||
"types-greenlet>=3.4.0",
|
||||
"types-greenlet>=3.5.0.20260428",
|
||||
"types-html5lib>=1.1.11",
|
||||
"types-markdown>=3.10.2",
|
||||
"types-oauthlib>=3.3.0",
|
||||
@ -143,7 +143,7 @@ dev = [
|
||||
"types-olefile>=0.47.0",
|
||||
"types-openpyxl>=3.1.5",
|
||||
"types-pexpect>=4.9.0",
|
||||
"types-protobuf>=7.34.1",
|
||||
"types-protobuf>=7.34.1.20260503",
|
||||
"types-psutil>=7.2.2",
|
||||
"types-psycopg2>=2.9.21.20260422",
|
||||
"types-pygments>=2.20.0",
|
||||
@ -158,11 +158,11 @@ dev = [
|
||||
"types-tensorflow>=2.18.0.20260408",
|
||||
"types-tqdm>=4.67.3.20260408",
|
||||
"types-ujson>=5.10.0",
|
||||
"boto3-stubs>=1.42.96",
|
||||
"boto3-stubs>=1.43.2",
|
||||
"types-jmespath>=1.1.0.20260408",
|
||||
"hypothesis>=6.152.3",
|
||||
"hypothesis>=6.152.4",
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=2.0.0.20260408",
|
||||
"types_cffi>=2.0.0.20260429",
|
||||
"types_setuptools>=82.0.0.20260408",
|
||||
"pandas-stubs>=3.0.0",
|
||||
"scipy-stubs>=1.17.1.4",
|
||||
@ -174,7 +174,7 @@ dev = [
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.62.0",
|
||||
"pyrefly>=0.64.0",
|
||||
"xinference-client>=2.7.0",
|
||||
]
|
||||
|
||||
@ -184,7 +184,7 @@ dev = [
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob>=12.28.0",
|
||||
"bce-python-sdk>=0.9.70",
|
||||
"bce-python-sdk>=0.9.71",
|
||||
"cos-python-sdk-v5>=1.9.42",
|
||||
"esdk-obs-python>=3.22.2",
|
||||
"google-cloud-storage>=3.10.1",
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.entities import Rule
|
||||
from core.rag.entities.metadata_entities import MetadataFilteringCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
@ -83,6 +84,7 @@ class RetrievalModel(BaseModel):
|
||||
score_threshold_enabled: bool
|
||||
score_threshold: float | None = None
|
||||
weights: WeightModel | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
|
||||
|
||||
class MetaDataConfig(BaseModel):
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -29,7 +31,7 @@ class TagBindingCreatePayload(BaseModel):
|
||||
|
||||
|
||||
class TagBindingDeletePayload(BaseModel):
|
||||
tag_id: str
|
||||
tag_ids: list[str] = Field(min_length=1)
|
||||
target_id: str
|
||||
type: TagType
|
||||
|
||||
@ -178,13 +180,18 @@ class TagService:
|
||||
@staticmethod
|
||||
def delete_tag_binding(payload: TagBindingDeletePayload):
|
||||
TagService.check_target_exists(payload.type, payload.target_id)
|
||||
tag_binding = db.session.scalar(
|
||||
select(TagBinding)
|
||||
.where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
|
||||
.limit(1)
|
||||
result = cast(
|
||||
CursorResult,
|
||||
db.session.execute(
|
||||
delete(TagBinding).where(
|
||||
TagBinding.target_id == payload.target_id,
|
||||
TagBinding.tag_id.in_(payload.tag_ids),
|
||||
TagBinding.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
),
|
||||
)
|
||||
if tag_binding:
|
||||
db.session.delete(tag_binding)
|
||||
|
||||
if result.rowcount:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1083,10 +1083,9 @@ class DraftVariableSaver:
|
||||
mimetype=content_type,
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
assert self._user.current_tenant_id
|
||||
# Create WorkflowDraftVariableFile record
|
||||
variable_file = WorkflowDraftVariableFile(
|
||||
id=uuidv7(),
|
||||
upload_file_id=upload_file.id,
|
||||
size=original_size,
|
||||
length=original_length,
|
||||
@ -1095,6 +1094,7 @@ class DraftVariableSaver:
|
||||
tenant_id=self._user.current_tenant_id,
|
||||
user_id=self._user.id,
|
||||
)
|
||||
variable_file.id = str(uuidv7())
|
||||
engine = bind = self._session.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
with sessionmaker(bind=engine, expire_on_commit=False).begin() as session:
|
||||
|
||||
@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]:
|
||||
def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]:
|
||||
"""
|
||||
Request context fixture for containerized Flask application.
|
||||
|
||||
@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None,
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]:
|
||||
def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]:
|
||||
"""
|
||||
Test client fixture for containerized Flask application.
|
||||
|
||||
@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]:
|
||||
def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Database session fixture for containerized testing.
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
@ -69,7 +70,7 @@ def _unwrap(func):
|
||||
|
||||
class TestCompletionEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
@ -86,7 +87,7 @@ class TestCompletionEndpoints:
|
||||
)
|
||||
assert payload.query == "hi"
|
||||
|
||||
def test_completion_api_success(self, app, monkeypatch):
|
||||
def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -116,7 +117,7 @@ class TestCompletionEndpoints:
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
|
||||
def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -142,7 +143,7 @@ class TestCompletionEndpoints:
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
|
||||
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -166,7 +167,7 @@ class TestCompletionEndpoints:
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app, monkeypatch):
|
||||
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -193,10 +194,10 @@ class TestCompletionEndpoints:
|
||||
|
||||
class TestAppEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
|
||||
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = app_module.AppApi()
|
||||
method = _unwrap(api.put)
|
||||
payload = {
|
||||
@ -234,7 +235,7 @@ class TestAppEndpoints:
|
||||
}
|
||||
)
|
||||
|
||||
def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
|
||||
def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = app_module.AppIconApi()
|
||||
method = _unwrap(api.post)
|
||||
payload = {
|
||||
@ -266,7 +267,7 @@ class TestAppEndpoints:
|
||||
|
||||
class TestOpsTraceEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
@ -277,7 +278,7 @@ class TestOpsTraceEndpoints:
|
||||
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
|
||||
assert payload.tracing_config["api_key"] == "k"
|
||||
|
||||
def test_trace_app_config_get_empty(self, app, monkeypatch):
|
||||
def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
@ -292,7 +293,7 @@ class TestOpsTraceEndpoints:
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app, monkeypatch):
|
||||
def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -309,7 +310,7 @@ class TestOpsTraceEndpoints:
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
|
||||
def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
@ -326,7 +327,7 @@ class TestOpsTraceEndpoints:
|
||||
|
||||
class TestSiteEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_site_response_structure(self):
|
||||
@ -337,7 +338,7 @@ class TestSiteEndpoints:
|
||||
payload = AppSiteUpdatePayload(default_language="en-US")
|
||||
assert payload.default_language == "en-US"
|
||||
|
||||
def test_app_site_update_post(self, app, monkeypatch):
|
||||
def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -375,7 +376,7 @@ class TestSiteEndpoints:
|
||||
assert isinstance(result, dict)
|
||||
assert result["title"] == "My Site"
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -427,7 +428,7 @@ class TestWorkflowEndpoints:
|
||||
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints:
|
||||
query = WorkflowAppLogQuery(detail="true")
|
||||
assert query.detail is True
|
||||
|
||||
def test_workflow_app_log_api_get(self, app, monkeypatch):
|
||||
def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints:
|
||||
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
def test_workflow_variable_collection_get(self, app, monkeypatch):
|
||||
def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints:
|
||||
assert query.start is None
|
||||
assert query.end is None
|
||||
|
||||
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
|
||||
def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints:
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
|
||||
def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints:
|
||||
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints:
|
||||
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
|
||||
assert enable_payload.enable_trigger is True
|
||||
|
||||
def test_webhook_trigger_api_get(self, app, monkeypatch):
|
||||
def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
|
||||
class TestAppImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -57,7 +58,7 @@ class TestAppImportApi:
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -75,7 +76,7 @@ class TestAppImportApi:
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -96,7 +97,7 @@ class TestAppImportApi:
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -121,7 +122,7 @@ class TestAppImportApi:
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -149,10 +150,10 @@ class TestAppImportApi:
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -172,10 +173,10 @@ class TestAppImportConfirmApi:
|
||||
|
||||
class TestAppImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
@ -16,7 +17,7 @@ from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers):
|
||||
def app(flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi:
|
||||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi:
|
||||
mock_revoke,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
|
||||
@ -120,7 +121,7 @@ class TestEmailRegisterResetApi:
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
@ -16,7 +17,7 @@ from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers):
|
||||
def app(flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
|
||||
@ -123,7 +124,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError
|
||||
|
||||
class TestGetOAuthProviders:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -65,7 +66,7 @@ class TestOAuthLogin:
|
||||
return OAuthLogin()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -89,7 +90,7 @@ class TestOAuthLogin:
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_oauth_provider,
|
||||
invite_token,
|
||||
expected_token,
|
||||
@ -130,7 +131,7 @@ class TestOAuthCallback:
|
||||
return OAuthCallback()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -164,7 +165,7 @@ class TestOAuthCallback:
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
@ -217,7 +218,7 @@ class TestOAuthCallback:
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
@ -261,7 +262,7 @@ class TestOAuthCallback:
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
oauth_setup,
|
||||
account_status,
|
||||
expected_redirect,
|
||||
@ -300,7 +301,7 @@ class TestOAuthCallback:
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
@ -336,7 +337,7 @@ class TestOAuthCallback:
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
app: Flask,
|
||||
oauth_setup,
|
||||
):
|
||||
"""Defensive test for CLOSED account status handling in OAuth callback.
|
||||
@ -394,7 +395,7 @@ class TestOAuthCallback:
|
||||
|
||||
class TestAccountGeneration:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -465,7 +466,7 @@ class TestAccountGeneration:
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
app: Flask,
|
||||
user_info,
|
||||
mock_account,
|
||||
allow_register,
|
||||
@ -504,7 +505,7 @@ class TestAccountGeneration:
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
@ -529,7 +530,7 @@ class TestAccountGeneration:
|
||||
mock_feature_service,
|
||||
mock_tenant_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
app: Flask,
|
||||
user_info,
|
||||
mock_account,
|
||||
):
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -46,7 +47,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_account,
|
||||
):
|
||||
# Arrange
|
||||
@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_send_email.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app):
|
||||
def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask):
|
||||
"""
|
||||
Test password reset email blocked by IP rate limit.
|
||||
|
||||
@ -104,7 +105,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_account,
|
||||
language_input,
|
||||
expected_language,
|
||||
@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi:
|
||||
"""Test cases for verifying password reset codes."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@ -153,7 +154,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
"""
|
||||
Test successful verification code validation.
|
||||
@ -200,7 +201,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app):
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask):
|
||||
"""
|
||||
Test code verification blocked by rate limit.
|
||||
|
||||
@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi:
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app):
|
||||
def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask):
|
||||
"""
|
||||
Test code verification with invalid token.
|
||||
|
||||
@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi:
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app):
|
||||
def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask):
|
||||
"""
|
||||
Test code verification with mismatched email.
|
||||
|
||||
@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app):
|
||||
def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask):
|
||||
"""
|
||||
Test code verification with incorrect code.
|
||||
|
||||
@ -321,7 +322,7 @@ class TestForgotPasswordResetApi:
|
||||
"""Test cases for resetting password with verified token."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -344,7 +345,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_account,
|
||||
):
|
||||
"""
|
||||
@ -375,7 +376,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_revoke_token.assert_called_once_with("valid_token")
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_mismatch(self, mock_get_data, app):
|
||||
def test_reset_password_mismatch(self, mock_get_data, app: Flask):
|
||||
"""
|
||||
Test password reset with mismatched passwords.
|
||||
|
||||
@ -397,7 +398,7 @@ class TestForgotPasswordResetApi:
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app):
|
||||
def test_reset_password_invalid_token(self, mock_get_data, app: Flask):
|
||||
"""
|
||||
Test password reset with invalid token.
|
||||
|
||||
@ -418,7 +419,7 @@ class TestForgotPasswordResetApi:
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app):
|
||||
def test_reset_password_wrong_phase(self, mock_get_data, app: Flask):
|
||||
"""
|
||||
Test password reset with token not in reset phase.
|
||||
|
||||
@ -442,7 +443,7 @@ class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app):
|
||||
def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -26,10 +27,10 @@ def unwrap(func):
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
def test_get_success(self, app: Flask):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -50,10 +51,10 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
def test_get_success(self, app: Flask):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -74,7 +75,7 @@ class TestPipelineTemplateDetailApi:
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
def test_get_returns_404_when_template_not_found(self, app):
|
||||
def test_get_returns_404_when_template_not_found(self, app: Flask):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -93,7 +94,7 @@ class TestPipelineTemplateDetailApi:
|
||||
assert status == 404
|
||||
assert "error" in response
|
||||
|
||||
def test_get_returns_404_for_customized_type_not_found(self, app):
|
||||
def test_get_returns_404_for_customized_type_not_found(self, app: Flask):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -115,10 +116,10 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
def test_patch_success(self, app: Flask):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -140,7 +141,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -155,7 +156,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app, db_session_with_containers: Session):
|
||||
def test_post_success(self, app: Flask, db_session_with_containers: Session):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -182,7 +183,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
def test_post_template_not_found(self, app: Flask):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -193,10 +194,10 @@ class TestCustomizedPipelineTemplateApi:
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app: Flask):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
@ -24,13 +25,13 @@ def unwrap(func):
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app: Flask):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi:
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
def test_post_forbidden_non_editor(self, app: Flask):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
def test_post_dataset_name_duplicate(self, app: Flask):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi:
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
def test_post_invalid_payload(self, app: Flask):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi:
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app: Flask):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi:
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
def test_post_forbidden_non_editor(self, app: Flask):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
@ -25,7 +26,7 @@ def unwrap(func):
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _payload(self, mode="create"):
|
||||
@ -35,7 +36,7 @@ class TestRagPipelineImportApi:
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
def test_post_success_200(self, app: Flask):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -65,7 +66,7 @@ class TestRagPipelineImportApi:
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
def test_post_failed_400(self, app: Flask):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -95,7 +96,7 @@ class TestRagPipelineImportApi:
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
def test_post_pending_202(self, app: Flask):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -128,10 +129,10 @@ class TestRagPipelineImportApi:
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_confirm_success(self, app):
|
||||
def test_confirm_success(self, app: Flask):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -159,7 +160,7 @@ class TestRagPipelineImportConfirmApi:
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
def test_confirm_failed(self, app: Flask):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -190,10 +191,10 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
def test_get_success(self, app: Flask):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -219,10 +220,10 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_with_include_secret(self, app):
|
||||
def test_get_with_include_secret(self, app: Flask):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
|
||||
|
||||
@ -45,10 +46,10 @@ def unwrap(func):
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app):
|
||||
def test_get_draft_success(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -68,7 +69,7 @@ class TestDraftWorkflowApi:
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
def test_get_draft_not_exist(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -86,7 +87,7 @@ class TestDraftWorkflowApi:
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
def test_sync_hash_not_match(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -111,7 +112,7 @@ class TestDraftWorkflowApi:
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
def test_sync_invalid_text_plain(self, app: Flask):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -128,7 +129,7 @@ class TestDraftWorkflowApi:
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
def test_restore_published_workflow_to_draft_success(self, app):
|
||||
def test_restore_published_workflow_to_draft_success(self, app: Flask):
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -155,7 +156,7 @@ class TestDraftWorkflowApi:
|
||||
assert result["result"] == "success"
|
||||
assert result["hash"] == "restored-hash"
|
||||
|
||||
def test_restore_published_workflow_to_draft_not_found(self, app):
|
||||
def test_restore_published_workflow_to_draft_not_found(self, app: Flask):
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -179,7 +180,7 @@ class TestDraftWorkflowApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "published-workflow")
|
||||
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app):
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask):
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -211,10 +212,10 @@ class TestDraftWorkflowApi:
|
||||
|
||||
class TestDraftRunNodes:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_iteration_node_success(self, app):
|
||||
def test_iteration_node_success(self, app: Flask):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -240,7 +241,7 @@ class TestDraftRunNodes:
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
def test_iteration_node_conversation_not_exists(self, app: Flask):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -262,7 +263,7 @@ class TestDraftRunNodes:
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
def test_loop_node_success(self, app: Flask):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -290,10 +291,10 @@ class TestDraftRunNodes:
|
||||
|
||||
class TestPipelineRunApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_draft_run_success(self, app):
|
||||
def test_draft_run_success(self, app: Flask):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -325,7 +326,7 @@ class TestPipelineRunApis:
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
def test_draft_run_rate_limit(self, app: Flask):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -356,10 +357,10 @@ class TestPipelineRunApis:
|
||||
|
||||
class TestDraftNodeRun:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_execution_not_found(self, app):
|
||||
def test_execution_not_found(self, app: Flask):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -387,10 +388,10 @@ class TestDraftNodeRun:
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_publish_success(self, app, db_session_with_containers: Session):
|
||||
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
|
||||
from models.dataset import Pipeline
|
||||
|
||||
api = PublishedRagPipelineApi()
|
||||
@ -436,10 +437,10 @@ class TestPublishedPipelineApis:
|
||||
|
||||
class TestMiscApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_task_stop(self, app):
|
||||
def test_task_stop(self, app: Flask):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -460,7 +461,7 @@ class TestMiscApis:
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
def test_transform_forbidden(self, app: Flask):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -476,7 +477,7 @@ class TestMiscApis:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
def test_recommended_plugins(self, app: Flask):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -496,10 +497,10 @@ class TestMiscApis:
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_published_run_success(self, app):
|
||||
def test_published_run_success(self, app: Flask):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi:
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
def test_published_run_rate_limit(self, app: Flask):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi:
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_block_config_success(self, app):
|
||||
def test_get_block_config_success(self, app: Flask):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi:
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
def test_get_block_config_invalid_json(self, app: Flask):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi:
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_published_workflows_success(self, app):
|
||||
def test_get_published_workflows_success(self, app: Flask):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi:
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
def test_get_published_workflows_forbidden(self, app: Flask):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi:
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
def test_patch_success(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -682,7 +683,7 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
def test_patch_no_fields(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -700,7 +701,7 @@ class TestRagPipelineByIdApi:
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
def test_delete_success(self, app):
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -720,7 +721,7 @@ class TestRagPipelineByIdApi:
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app):
|
||||
def test_delete_active_workflow_rejected(self, app: Flask):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -733,10 +734,10 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_last_run_success(self, app):
|
||||
def test_last_run_success(self, app: Flask):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
def test_last_run_not_found(self, app: Flask):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
def test_set_datasource_variables_success(self, app: Flask):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.datasets import data_source
|
||||
@ -51,10 +52,10 @@ def mock_engine():
|
||||
|
||||
class TestDataSourceApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -78,7 +79,7 @@ class TestDataSourceApi:
|
||||
assert status == 200
|
||||
assert response["data"][0]["is_bound"] is True
|
||||
|
||||
def test_get_no_bindings(self, app, patch_tenant):
|
||||
def test_get_no_bindings(self, app: Flask, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -94,7 +95,7 @@ class TestDataSourceApi:
|
||||
assert status == 200
|
||||
assert response["data"] == []
|
||||
|
||||
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
|
||||
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -115,7 +116,7 @@ class TestDataSourceApi:
|
||||
assert status == 200
|
||||
assert binding.disabled is False
|
||||
|
||||
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
|
||||
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -136,7 +137,7 @@ class TestDataSourceApi:
|
||||
assert status == 200
|
||||
assert binding.disabled is True
|
||||
|
||||
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
|
||||
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -151,7 +152,7 @@ class TestDataSourceApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
|
||||
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -168,7 +169,7 @@ class TestDataSourceApi:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
|
||||
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -188,10 +189,10 @@ class TestDataSourceApi:
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
def test_get_credential_not_found(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -205,7 +206,7 @@ class TestDataSourceNotionListApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -246,7 +247,7 @@ class TestDataSourceNotionListApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -299,7 +300,7 @@ class TestDataSourceNotionListApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
|
||||
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -323,10 +324,10 @@ class TestDataSourceNotionListApi:
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
def test_get_preview_success(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -347,7 +348,7 @@ class TestDataSourceNotionApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_indexing_estimate_success(self, app, patch_tenant):
|
||||
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -381,10 +382,10 @@ class TestDataSourceNotionApi:
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -407,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app, patch_tenant):
|
||||
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -424,10 +425,10 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -450,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_document_not_found(self, app, patch_tenant):
|
||||
def test_get_document_not_found(self, app: Flask, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
@ -53,10 +54,10 @@ def user():
|
||||
|
||||
class TestConversationListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, chat_app, user):
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -81,7 +82,7 @@ class TestConversationListApi:
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app, chat_app, user):
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -97,7 +98,7 @@ class TestConversationListApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app, non_chat_app):
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -108,10 +109,10 @@ class TestConversationListApi:
|
||||
|
||||
class TestConversationApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_delete_success(self, app, chat_app, user):
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -129,7 +130,7 @@ class TestConversationApi:
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app, chat_app, user):
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -145,7 +146,7 @@ class TestConversationApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app, non_chat_app):
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -156,10 +157,10 @@ class TestConversationApi:
|
||||
|
||||
class TestConversationRenameApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_rename_success(self, app, chat_app, user):
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -178,7 +179,7 @@ class TestConversationRenameApi:
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app, chat_app, user):
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -197,10 +198,10 @@ class TestConversationRenameApi:
|
||||
|
||||
class TestConversationPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_pin_success(self, app, chat_app, user):
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -219,10 +220,10 @@ class TestConversationPinApi:
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_unpin_success(self, app, chat_app, user):
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
@ -60,7 +61,7 @@ def _mock_user_tenant():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(flask_app_with_containers):
|
||||
def client(flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers.test_client()
|
||||
|
||||
|
||||
@ -147,10 +148,10 @@ class TestUtils:
|
||||
|
||||
class TestToolProviderListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
def test_get_success(self, app: Flask):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -170,10 +171,10 @@ class TestToolProviderListApi:
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_tools(self, app):
|
||||
def test_list_tools(self, app: Flask):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -190,7 +191,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider") == [{"a": 1}]
|
||||
|
||||
def test_info(self, app):
|
||||
def test_info(self, app: Flask):
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -207,7 +208,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider") == {"x": 1}
|
||||
|
||||
def test_delete(self, app):
|
||||
def test_delete(self, app: Flask):
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -224,7 +225,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider")["result"] == "success"
|
||||
|
||||
def test_add_invalid_type(self, app):
|
||||
def test_add_invalid_type(self, app: Flask):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -238,7 +239,7 @@ class TestBuiltinProviderApis:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "provider")
|
||||
|
||||
def test_add_success(self, app):
|
||||
def test_add_success(self, app: Flask):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -257,7 +258,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider")["id"] == 1
|
||||
|
||||
def test_update(self, app):
|
||||
def test_update(self, app: Flask):
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -276,7 +277,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credentials(self, app):
|
||||
def test_get_credentials(self, app: Flask):
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -293,7 +294,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider") == {"k": "v"}
|
||||
|
||||
def test_icon(self, app):
|
||||
def test_icon(self, app: Flask):
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -307,7 +308,7 @@ class TestBuiltinProviderApis:
|
||||
response = method(api, "provider")
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_credentials_schema(self, app):
|
||||
def test_credentials_schema(self, app: Flask):
|
||||
api = ToolBuiltinProviderCredentialsSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -324,7 +325,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider", "oauth2") == {"schema": {}}
|
||||
|
||||
def test_set_default_credential(self, app):
|
||||
def test_set_default_credential(self, app: Flask):
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -341,7 +342,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credential_info(self, app):
|
||||
def test_get_credential_info(self, app: Flask):
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -358,7 +359,7 @@ class TestBuiltinProviderApis:
|
||||
):
|
||||
assert method(api, "provider") == {"info": "x"}
|
||||
|
||||
def test_get_oauth_client_schema(self, app):
|
||||
def test_get_oauth_client_schema(self, app: Flask):
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -378,10 +379,10 @@ class TestBuiltinProviderApis:
|
||||
|
||||
class TestApiProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_add(self, app):
|
||||
def test_add(self, app: Flask):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -406,7 +407,7 @@ class TestApiProviderApis:
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_remote_schema(self, app):
|
||||
def test_remote_schema(self, app: Flask):
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -423,7 +424,7 @@ class TestApiProviderApis:
|
||||
):
|
||||
assert method(api)["schema"] == "x"
|
||||
|
||||
def test_list_tools(self, app):
|
||||
def test_list_tools(self, app: Flask):
|
||||
api = ToolApiProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -440,7 +441,7 @@ class TestApiProviderApis:
|
||||
):
|
||||
assert method(api) == [{"tool": 1}]
|
||||
|
||||
def test_update(self, app):
|
||||
def test_update(self, app: Flask):
|
||||
api = ToolApiProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -468,7 +469,7 @@ class TestApiProviderApis:
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
def test_delete(self, app: Flask):
|
||||
api = ToolApiProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -485,7 +486,7 @@ class TestApiProviderApis:
|
||||
):
|
||||
assert method(api)["result"] == "success"
|
||||
|
||||
def test_get(self, app):
|
||||
def test_get(self, app: Flask):
|
||||
api = ToolApiProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -505,10 +506,10 @@ class TestApiProviderApis:
|
||||
|
||||
class TestWorkflowApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create(self, app):
|
||||
def test_create(self, app: Flask):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -534,7 +535,7 @@ class TestWorkflowApis:
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_update_invalid(self, app):
|
||||
def test_update_invalid(self, app: Flask):
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -560,7 +561,7 @@ class TestWorkflowApis:
|
||||
result = method(api)
|
||||
assert result["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
def test_delete(self, app: Flask):
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -577,7 +578,7 @@ class TestWorkflowApis:
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_get_error(self, app):
|
||||
def test_get_error(self, app: Flask):
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -594,10 +595,10 @@ class TestWorkflowApis:
|
||||
|
||||
class TestLists:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_builtin_list(self, app):
|
||||
def test_builtin_list(self, app: Flask):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -617,7 +618,7 @@ class TestLists:
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_api_list(self, app):
|
||||
def test_api_list(self, app: Flask):
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -637,7 +638,7 @@ class TestLists:
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_workflow_list(self, app):
|
||||
def test_workflow_list(self, app: Flask):
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -660,10 +661,10 @@ class TestLists:
|
||||
|
||||
class TestLabels:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_labels(self, app):
|
||||
def test_labels(self, app: Flask):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -679,10 +680,10 @@ class TestLabels:
|
||||
|
||||
class TestOAuth:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_no_client(self, app):
|
||||
def test_oauth_no_client(self, app: Flask):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -700,7 +701,7 @@ class TestOAuth:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
def test_oauth_callback_no_cookie(self, app):
|
||||
def test_oauth_callback_no_cookie(self, app: Flask):
|
||||
api = ToolOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -711,10 +712,10 @@ class TestOAuth:
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_save_custom_client(self, app):
|
||||
def test_save_custom_client(self, app: Flask):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -731,7 +732,7 @@ class TestOAuthCustomClient:
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_custom_client(self, app):
|
||||
def test_get_custom_client(self, app: Flask):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -748,7 +749,7 @@ class TestOAuthCustomClient:
|
||||
):
|
||||
assert method(api, "provider") == {"client_id": "x"}
|
||||
|
||||
def test_delete_custom_client(self, app):
|
||||
def test_delete_custom_client(self, app: Flask):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from controllers.console.workspace.trigger_providers import (
|
||||
@ -45,10 +46,10 @@ def mock_user():
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_icon_success(self, app):
|
||||
def test_icon_success(self, app: Flask):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -62,7 +63,7 @@ class TestTriggerProviderApis:
|
||||
):
|
||||
assert method(api, "github") == "icon"
|
||||
|
||||
def test_list_providers(self, app):
|
||||
def test_list_providers(self, app: Flask):
|
||||
api = TriggerProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -76,7 +77,7 @@ class TestTriggerProviderApis:
|
||||
):
|
||||
assert method(api) == []
|
||||
|
||||
def test_provider_info(self, app):
|
||||
def test_provider_info(self, app: Flask):
|
||||
api = TriggerProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -93,10 +94,10 @@ class TestTriggerProviderApis:
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_success(self, app):
|
||||
def test_list_success(self, app: Flask):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -110,7 +111,7 @@ class TestTriggerSubscriptionListApi:
|
||||
):
|
||||
assert method(api, "github") == []
|
||||
|
||||
def test_list_invalid_provider(self, app):
|
||||
def test_list_invalid_provider(self, app: Flask):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -128,10 +129,10 @@ class TestTriggerSubscriptionListApi:
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create_builder(self, app):
|
||||
def test_create_builder(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -146,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
result = method(api, "github")
|
||||
assert "subscription_builder" in result
|
||||
|
||||
def test_get_builder(self, app):
|
||||
def test_get_builder(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -159,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_verify_builder(self, app):
|
||||
def test_verify_builder(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -173,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert method(api, "github", "b1") == {"ok": True}
|
||||
|
||||
def test_verify_builder_error(self, app):
|
||||
def test_verify_builder_error(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -188,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github", "b1")
|
||||
|
||||
def test_update_builder(self, app):
|
||||
def test_update_builder(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -202,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_logs(self, app):
|
||||
def test_logs(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -219,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
|
||||
def test_build(self, app):
|
||||
def test_build(self, app: Flask):
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -236,10 +237,10 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_update_rename_only(self, app):
|
||||
def test_update_rename_only(self, app: Flask):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -258,7 +259,7 @@ class TestTriggerSubscriptionCrud:
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_update_not_found(self, app):
|
||||
def test_update_not_found(self, app: Flask):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -273,7 +274,7 @@ class TestTriggerSubscriptionCrud:
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "x")
|
||||
|
||||
def test_update_rebuild(self, app):
|
||||
def test_update_rebuild(self, app: Flask):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -296,7 +297,7 @@ class TestTriggerSubscriptionCrud:
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_delete_subscription(self, app):
|
||||
def test_delete_subscription(self, app: Flask):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -319,7 +320,7 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_subscription_value_error(self, app):
|
||||
def test_delete_subscription_value_error(self, app: Flask):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -342,10 +343,10 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_authorize_success(self, app):
|
||||
def test_oauth_authorize_success(self, app: Flask):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -372,7 +373,7 @@ class TestTriggerOAuthApis:
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_oauth_authorize_no_client(self, app):
|
||||
def test_oauth_authorize_no_client(self, app: Flask):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -387,7 +388,7 @@ class TestTriggerOAuthApis:
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_forbidden(self, app):
|
||||
def test_oauth_callback_forbidden(self, app: Flask):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -395,7 +396,7 @@ class TestTriggerOAuthApis:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_success(self, app):
|
||||
def test_oauth_callback_success(self, app: Flask):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -425,7 +426,7 @@ class TestTriggerOAuthApis:
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 302
|
||||
|
||||
def test_oauth_callback_no_oauth_client(self, app):
|
||||
def test_oauth_callback_no_oauth_client(self, app: Flask):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -449,7 +450,7 @@ class TestTriggerOAuthApis:
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_empty_credentials(self, app):
|
||||
def test_oauth_callback_empty_credentials(self, app: Flask):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -480,10 +481,10 @@ class TestTriggerOAuthApis:
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_client(self, app):
|
||||
def test_get_client(self, app: Flask):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -510,7 +511,7 @@ class TestTriggerOAuthClientManageApi:
|
||||
result = method(api, "github")
|
||||
assert "configured" in result
|
||||
|
||||
def test_post_client(self, app):
|
||||
def test_post_client(self, app: Flask):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -524,7 +525,7 @@ class TestTriggerOAuthClientManageApi:
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_delete_client(self, app):
|
||||
def test_delete_client(self, app: Flask):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@ -538,7 +539,7 @@ class TestTriggerOAuthClientManageApi:
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app):
|
||||
def test_oauth_client_post_value_error(self, app: Flask):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -556,10 +557,10 @@ class TestTriggerOAuthClientManageApi:
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_verify_success(self, app):
|
||||
def test_verify_success(self, app: Flask):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -217,10 +218,20 @@ class TestTagUnbindingPayload:
|
||||
"""Test suite for TagUnbindingPayload Pydantic model."""
|
||||
|
||||
def test_payload_with_valid_data(self):
|
||||
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
|
||||
assert payload.tag_id == "tag_123"
|
||||
payload = TagUnbindingPayload(tag_ids=["tag_123"], target_id="dataset_456")
|
||||
assert payload.tag_ids == ["tag_123"]
|
||||
assert payload.target_id == "dataset_456"
|
||||
|
||||
def test_payload_normalizes_legacy_tag_id(self):
|
||||
payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456")
|
||||
assert payload.tag_ids == ["tag_123"]
|
||||
assert payload.target_id == "dataset_456"
|
||||
|
||||
def test_payload_rejects_empty_tag_ids(self):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
TagUnbindingPayload(tag_ids=[], target_id="dataset_456")
|
||||
assert "Tag IDs is required" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@ -236,7 +247,7 @@ def _unwrap(method):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(flask_app_with_containers):
|
||||
def app(flask_app_with_containers: Flask):
|
||||
# Uses the full containerised app so that Flask config, extensions, and
|
||||
# blueprint registrations match production. Most tests mock the service
|
||||
# layer to isolate controller logic; a few (e.g. test_list_tags_from_db)
|
||||
@ -280,7 +291,7 @@ class TestDatasetListApiGet:
|
||||
mock_current_user,
|
||||
mock_provider_mgr,
|
||||
mock_marshal,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetListApi
|
||||
@ -315,7 +326,7 @@ class TestDatasetListApiPost:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_marshal,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetListApi
|
||||
@ -341,7 +352,7 @@ class TestDatasetListApiPost:
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetListApi
|
||||
@ -379,7 +390,7 @@ class TestDatasetApiGet:
|
||||
mock_provider_mgr,
|
||||
mock_marshal,
|
||||
mock_perm_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_dataset,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetApi
|
||||
@ -429,7 +440,7 @@ class TestDatasetApiGet:
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_dataset,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetApi
|
||||
@ -457,7 +468,7 @@ class TestDatasetApiDelete:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_perm_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_dataset,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetApi
|
||||
@ -479,7 +490,7 @@ class TestDatasetApiDelete:
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_dataset,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetApi
|
||||
@ -500,7 +511,7 @@ class TestDatasetApiDelete:
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_dataset,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetApi
|
||||
@ -532,7 +543,7 @@ class TestDocumentStatusApiPatch:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_doc_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
@ -563,7 +574,7 @@ class TestDocumentStatusApiPatch:
|
||||
def test_batch_update_status_dataset_not_found(
|
||||
self,
|
||||
mock_dataset_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
@ -592,7 +603,7 @@ class TestDocumentStatusApiPatch:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_doc_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
@ -625,7 +636,7 @@ class TestDocumentStatusApiPatch:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_doc_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
@ -658,7 +669,7 @@ class TestDocumentStatusApiPatch:
|
||||
mock_dataset_svc,
|
||||
mock_current_user,
|
||||
mock_doc_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
):
|
||||
@ -698,7 +709,7 @@ class TestDatasetTagsApiGet:
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
@ -720,7 +731,7 @@ class TestDatasetTagsApiGet:
|
||||
def test_list_tags_from_db(
|
||||
self,
|
||||
mock_current_user,
|
||||
app,
|
||||
app: Flask,
|
||||
db_session_with_containers: Session,
|
||||
):
|
||||
"""Integration test: creates real Tag rows and retrieves them
|
||||
@ -763,7 +774,7 @@ class TestDatasetTagsApiPost:
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
@ -786,7 +797,7 @@ class TestDatasetTagsApiPost:
|
||||
mock_tag_svc.save_tags.assert_called_once()
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_create_tag_forbidden(self, mock_current_user, app):
|
||||
def test_create_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
mock_current_user.__class__ = Account
|
||||
@ -815,7 +826,7 @@ class TestDatasetTagsApiPatch:
|
||||
mock_current_user,
|
||||
mock_service_api_ns,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
@ -841,7 +852,7 @@ class TestDatasetTagsApiPatch:
|
||||
mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1")
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_update_tag_forbidden(self, mock_current_user, app):
|
||||
def test_update_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
mock_current_user.__class__ = Account
|
||||
@ -869,7 +880,7 @@ class TestDatasetTagsApiDelete:
|
||||
mock_current_user,
|
||||
mock_service_api_ns,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
@ -894,7 +905,7 @@ class TestDatasetTagsApiDelete:
|
||||
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
|
||||
|
||||
@patch("libs.login.current_user")
|
||||
def test_delete_tag_forbidden(self, mock_current_user, app):
|
||||
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsApi
|
||||
|
||||
user_obj = Mock(spec=Account)
|
||||
@ -922,7 +933,7 @@ class TestDatasetTagsBindingStatusApi:
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi
|
||||
|
||||
@ -952,7 +963,7 @@ class TestDatasetTagBindingApiPost:
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
|
||||
|
||||
@ -977,7 +988,7 @@ class TestDatasetTagBindingApiPost:
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_bind_tags_forbidden(self, mock_current_user, app):
|
||||
def test_bind_tags_forbidden(self, mock_current_user, app: Flask):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
|
||||
|
||||
mock_current_user.__class__ = Account
|
||||
@ -1003,7 +1014,37 @@ class TestDatasetTagUnbindingApiPost:
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
|
||||
|
||||
mock_current_user.__class__ = Account
|
||||
mock_current_user.has_edit_permission = True
|
||||
mock_current_user.is_dataset_editor = True
|
||||
mock_tag_svc.delete_tag_binding.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
"/datasets/tags/unbinding",
|
||||
method="POST",
|
||||
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
|
||||
):
|
||||
api = DatasetTagUnbindingApi()
|
||||
result = api.post(_=None)
|
||||
|
||||
assert result == ("", 204)
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.TagService")
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_unbind_legacy_tag_id_success(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_tag_svc,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
|
||||
|
||||
@ -1024,11 +1065,11 @@ class TestDatasetTagUnbindingApiPost:
|
||||
from services.tag_service import TagBindingDeletePayload
|
||||
|
||||
mock_tag_svc.delete_tag_binding.assert_called_once_with(
|
||||
TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge")
|
||||
TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge")
|
||||
)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
def test_unbind_tag_forbidden(self, mock_current_user, app):
|
||||
def test_unbind_tag_forbidden(self, mock_current_user, app: Flask):
|
||||
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
|
||||
|
||||
mock_current_user.__class__ = Account
|
||||
@ -1038,7 +1079,7 @@ class TestDatasetTagUnbindingApiPost:
|
||||
with app.test_request_context(
|
||||
"/datasets/tags/unbinding",
|
||||
method="POST",
|
||||
json={"tag_id": "tag-1", "target_id": "ds-1"},
|
||||
json={"tag_ids": ["tag-1"], "target_id": "ds-1"},
|
||||
):
|
||||
api = DatasetTagUnbindingApi()
|
||||
with pytest.raises(Forbidden):
|
||||
|
||||
@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace:
|
||||
|
||||
class TestConversationListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
@ -65,16 +66,16 @@ class TestConversationListApi:
|
||||
|
||||
class TestConversationApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
@ -83,7 +84,7 @@ class TestConversationApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
@ -92,17 +93,17 @@ class TestConversationApi:
|
||||
|
||||
class TestConversationRenameApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
@ -126,7 +127,7 @@ class TestConversationRenameApi:
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
@ -137,16 +138,16 @@ class TestConversationRenameApi:
|
||||
|
||||
class TestConversationPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
@ -154,7 +155,7 @@ class TestConversationPinApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
@ -163,16 +164,16 @@ class TestConversationPinApi:
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
@ -7,6 +7,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
@ -29,7 +30,7 @@ def _patch_wraps():
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
mock_send_mail,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
|
||||
@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi:
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
|
||||
@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi:
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@ -157,7 +158,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_db,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
@ -194,7 +195,7 @@ class TestForgotPasswordResetApi:
|
||||
mock_db,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
@ -182,7 +183,7 @@ class TestValidateUserAccessibility:
|
||||
|
||||
class TestDecodeJwtToken:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
|
||||
@ -239,7 +240,7 @@ class TestDecodeJwtToken:
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app,
|
||||
app: Flask,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
@ -299,7 +300,7 @@ class TestDecodeJwtToken:
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app,
|
||||
app: Flask,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
|
||||
@ -324,7 +325,7 @@ class TestDecodeJwtToken:
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app,
|
||||
app: Flask,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
|
||||
@ -350,7 +351,7 @@ class TestDecodeJwtToken:
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app,
|
||||
app: Flask,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
|
||||
@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
return WorkflowRunService(engine)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
|
||||
def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
generate_entity=entity,
|
||||
)
|
||||
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session):
|
||||
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
|
||||
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers: Session):
|
||||
"""Test that pause state can be persisted and retrieved correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert retrieved_state["node_run_steps"] == 10
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_database_transaction_handling(self, db_session_with_containers):
|
||||
def test_database_transaction_handling(self, db_session_with_containers: Session):
|
||||
"""Test that database transactions are handled correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
def test_file_storage_integration(self, db_session_with_containers):
|
||||
def test_file_storage_integration(self, db_session_with_containers: Session):
|
||||
"""Test integration with file storage system."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers: Session):
|
||||
"""Test pause state with workflows created by different users."""
|
||||
# Arrange - Create workflow with different creator
|
||||
different_user_id = str(uuid.uuid4())
|
||||
@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
|
||||
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session):
|
||||
"""Test that layer ignores non-pause events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
).all()
|
||||
assert len(pause_states) == 0
|
||||
|
||||
def test_layer_requires_initialization(self, db_session_with_containers):
|
||||
def test_layer_requires_initialization(self, db_session_with_containers: Session):
|
||||
"""Test that layer requires proper initialization before handling events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
@ -15,11 +15,14 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
|
||||
|
||||
TenantAndAccount = tuple[Tenant, Account]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestTask:
|
||||
@ -40,7 +43,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
@ -73,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
return tenant, account
|
||||
|
||||
@pytest.fixture
|
||||
def test_queue(self, test_tenant_and_account):
|
||||
def test_queue(self, test_tenant_and_account: TenantAndAccount):
|
||||
"""Create a generic test queue for testing."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
|
||||
|
||||
@pytest.fixture
|
||||
def secondary_queue(self, test_tenant_and_account):
|
||||
def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
|
||||
"""Create a secondary test queue for testing isolation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
|
||||
|
||||
def test_queue_initialization(self, test_tenant_and_account):
|
||||
def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
|
||||
"""Test queue initialization with correct key generation."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
|
||||
@ -94,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
|
||||
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
|
||||
|
||||
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
|
||||
def test_tenant_isolation(
|
||||
self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
|
||||
):
|
||||
"""Test that different tenants have isolated queues."""
|
||||
tenant1, _ = test_tenant_and_account
|
||||
|
||||
@ -114,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
|
||||
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
|
||||
|
||||
def test_key_isolation(self, test_tenant_and_account):
|
||||
def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
|
||||
"""Test that different keys have isolated queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
|
||||
@ -176,7 +181,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert len(remaining_tasks) == 2
|
||||
assert remaining_tasks == ["task4", "task5"]
|
||||
|
||||
def test_push_and_pull_complex_objects(self, test_queue, fake):
|
||||
def test_push_and_pull_complex_objects(self, test_queue, fake: Faker):
|
||||
"""Test pushing and pulling complex object tasks."""
|
||||
# Create complex task objects as dictionaries (not dataclass instances)
|
||||
tasks = [
|
||||
@ -218,7 +223,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert pulled_task["data"] == original_task["data"]
|
||||
assert pulled_task["metadata"] == original_task["metadata"]
|
||||
|
||||
def test_mixed_task_types(self, test_queue, fake):
|
||||
def test_mixed_task_types(self, test_queue, fake: Faker):
|
||||
"""Test pushing and pulling mixed string and object tasks."""
|
||||
string_task = "simple_string_task"
|
||||
object_task = {
|
||||
@ -267,7 +272,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
# Verify task key has expired
|
||||
assert test_queue.get_task_key() is None
|
||||
|
||||
def test_large_task_batch(self, test_queue, fake):
|
||||
def test_large_task_batch(self, test_queue, fake: Faker):
|
||||
"""Test handling large batches of tasks."""
|
||||
# Create large batch of tasks
|
||||
large_batch = []
|
||||
@ -292,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert isinstance(task, dict)
|
||||
assert task["index"] == i # FIFO order
|
||||
|
||||
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
|
||||
def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
|
||||
"""Test concurrent operations on different queues."""
|
||||
tenant, _ = test_tenant_and_account
|
||||
|
||||
@ -312,7 +317,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
assert tasks2 == ["task1_queue2", "task2_queue2"]
|
||||
assert tasks1 != tasks2
|
||||
|
||||
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
|
||||
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker):
|
||||
"""Test TaskWrapper serialization and deserialization roundtrip."""
|
||||
# Create complex nested data
|
||||
complex_data = {
|
||||
@ -346,7 +351,7 @@ class TestTenantIsolatedTaskQueueIntegration:
|
||||
task = test_queue.pull_tasks(1)
|
||||
assert task[0] == invalid_json_task
|
||||
|
||||
def test_real_world_batch_processing_scenario(self, test_queue, fake):
|
||||
def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker):
|
||||
"""Test realistic batch processing scenario."""
|
||||
# Simulate batch processing tasks
|
||||
batch_tasks = []
|
||||
@ -403,7 +408,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
|
||||
return Faker()
|
||||
|
||||
@pytest.fixture
|
||||
def test_tenant_and_account(self, db_session_with_containers, fake):
|
||||
def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
|
||||
"""Create test tenant and account for testing."""
|
||||
# Create account
|
||||
account = Account(
|
||||
@ -435,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
|
||||
|
||||
return tenant, account
|
||||
|
||||
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
|
||||
def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
|
||||
"""
|
||||
Test compatibility with legacy queues containing only string data.
|
||||
|
||||
@ -465,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
|
||||
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
|
||||
assert pulled_tasks == expected_order
|
||||
|
||||
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
|
||||
def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
|
||||
"""
|
||||
Test complete migration scenario from legacy to new system.
|
||||
|
||||
@ -546,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
|
||||
assert task["tenant_id"] == tenant.id
|
||||
assert task["processing_type"] == "new_system"
|
||||
|
||||
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
|
||||
def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
|
||||
"""
|
||||
Test error recovery when legacy queue contains malformed data.
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw
|
||||
|
||||
class TestGetAvailableDatasetsIntegration:
|
||||
def test_returns_datasets_with_available_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert result[0].name == dataset.name
|
||||
|
||||
def test_filters_out_datasets_with_only_archived_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert len(result) == 0
|
||||
|
||||
def test_filters_out_datasets_with_only_disabled_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert len(result) == 0
|
||||
|
||||
def test_filters_out_datasets_with_non_completed_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert len(result) == 0
|
||||
|
||||
def test_includes_external_datasets_without_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that external datasets are returned even with no available documents.
|
||||
@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert result[0].id == dataset.id
|
||||
assert result[0].provider == "external"
|
||||
|
||||
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
assert result[0].tenant_id == tenant1.id
|
||||
|
||||
def test_returns_empty_list_when_no_datasets_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_returns_only_requested_dataset_ids(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration:
|
||||
|
||||
class TestKnowledgeRetrievalIntegration:
|
||||
def test_knowledge_retrieval_with_available_datasets(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_knowledge_retrieval_no_available_datasets(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration:
|
||||
assert result == []
|
||||
|
||||
def test_knowledge_retrieval_rate_limit_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
# Arrange
|
||||
fake = Faker()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@ -16,7 +17,7 @@ from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
class TestStorageKeyLoader(unittest.TestCase):
|
||||
class TestStorageKeyLoader:
|
||||
"""
|
||||
Integration tests for StorageKeyLoader class.
|
||||
|
||||
@ -24,110 +25,82 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test data before each test method."""
|
||||
self.session = db.session()
|
||||
self.tenant_id = str(uuid4())
|
||||
self.user_id = str(uuid4())
|
||||
self.conversation_id = str(uuid4())
|
||||
|
||||
# Create test data that will be cleaned up after each test
|
||||
self.test_upload_files = []
|
||||
self.test_tool_files = []
|
||||
|
||||
# Create StorageKeyLoader instance
|
||||
self.loader = StorageKeyLoader(
|
||||
self.session,
|
||||
self.tenant_id,
|
||||
access_controller=DatabaseFileAccessController(),
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test data after each test method."""
|
||||
self.session.rollback()
|
||||
# ------------------------------------------------------------------
|
||||
# Per-test helpers (use db_session_with_containers as parameter)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _create_upload_file(
|
||||
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
file_id: str | None = None,
|
||||
storage_key: str | None = None,
|
||||
override_tenant_id: str | None = None,
|
||||
) -> UploadFile:
|
||||
"""Helper method to create an UploadFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if storage_key is None:
|
||||
storage_key = f"test_storage_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
"""Create and flush an UploadFile record for testing."""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||
storage_type=StorageType.LOCAL,
|
||||
key=storage_key,
|
||||
key=storage_key or f"test_storage_key_{uuid4()}",
|
||||
name="test_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_by=user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
self.session.add(upload_file)
|
||||
self.session.flush()
|
||||
self.test_upload_files.append(upload_file)
|
||||
|
||||
upload_file.id = file_id or str(uuid4())
|
||||
session.add(upload_file)
|
||||
session.flush()
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def _create_tool_file(
|
||||
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
*,
|
||||
file_id: str | None = None,
|
||||
file_key: str | None = None,
|
||||
override_tenant_id: str | None = None,
|
||||
) -> ToolFile:
|
||||
"""Helper method to create a ToolFile record for testing."""
|
||||
if file_id is None:
|
||||
file_id = str(uuid4())
|
||||
if file_key is None:
|
||||
file_key = f"test_file_key_{uuid4()}"
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
"""Create and flush a ToolFile record for testing."""
|
||||
tool_file = ToolFile(
|
||||
user_id=self.user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=self.conversation_id,
|
||||
file_key=file_key,
|
||||
user_id=user_id,
|
||||
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=file_key or f"test_file_key_{uuid4()}",
|
||||
mimetype="text/plain",
|
||||
original_url="http://example.com/file.txt",
|
||||
name="test_tool_file.txt",
|
||||
size=2048,
|
||||
)
|
||||
tool_file.id = file_id
|
||||
|
||||
self.session.add(tool_file)
|
||||
self.session.flush()
|
||||
self.test_tool_files.append(tool_file)
|
||||
|
||||
tool_file.id = file_id or str(uuid4())
|
||||
session.add(tool_file)
|
||||
session.flush()
|
||||
return tool_file
|
||||
|
||||
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
|
||||
"""Helper method to create a File object for testing."""
|
||||
if tenant_id is None:
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
||||
file_related_id = None
|
||||
remote_url = None
|
||||
|
||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
||||
file_related_id = related_id
|
||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
remote_url = "https://example.com/test_file.txt"
|
||||
file_related_id = related_id
|
||||
|
||||
@staticmethod
|
||||
def _create_file(
|
||||
tenant_id: str,
|
||||
related_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
*,
|
||||
override_tenant_id: str | None = None,
|
||||
) -> File:
|
||||
"""Build a File value-object for testing."""
|
||||
remote_url = "https://example.com/test_file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
|
||||
return File(
|
||||
file_id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
file_id=str(uuid4()),
|
||||
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
related_id=related_id,
|
||||
remote_url=remote_url,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
@ -136,240 +109,280 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
storage_key="initial_key",
|
||||
)
|
||||
|
||||
def test_load_storage_keys_local_file(self):
|
||||
# ------------------------------------------------------------------
|
||||
# Tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_load_storage_keys_local_file(self, db_session_with_containers: Session):
|
||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_remote_url(self):
|
||||
def test_load_storage_keys_remote_url(self, db_session_with_containers: Session):
|
||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_tool_file(self):
|
||||
def test_load_storage_keys_tool_file(self, db_session_with_containers: Session):
|
||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||
# Create test data
|
||||
tool_file = self._create_tool_file()
|
||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys([file])
|
||||
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
|
||||
file = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
# Verify storage key was loaded correctly
|
||||
assert file._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_mixed_methods(self):
|
||||
def test_load_storage_keys_mixed_methods(self, db_session_with_containers: Session):
|
||||
"""Test batch loading with mixed transfer methods."""
|
||||
# Create test data for different transfer methods
|
||||
upload_file1 = self._create_upload_file()
|
||||
upload_file2 = self._create_upload_file()
|
||||
tool_file = self._create_tool_file()
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
upload_file1 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
upload_file2 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
|
||||
|
||||
files = [file1, file2, file3]
|
||||
file1 = self._create_file(tenant_id, related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(tenant_id, related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||
file3 = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Load storage keys
|
||||
self.loader.load_storage_keys(files)
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file1, file2, file3])
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
assert file1._storage_key == upload_file1.key
|
||||
assert file2._storage_key == upload_file2.key
|
||||
assert file3._storage_key == tool_file.file_key
|
||||
|
||||
def test_load_storage_keys_empty_list(self):
|
||||
"""Test with empty file list."""
|
||||
# Should not raise any exceptions
|
||||
self.loader.load_storage_keys([])
|
||||
def test_load_storage_keys_empty_list(self, db_session_with_containers: Session):
|
||||
"""Test with empty file list — should not raise."""
|
||||
tenant_id = str(uuid4())
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([])
|
||||
|
||||
def test_load_storage_keys_ignores_legacy_file_tenant_id(self):
|
||||
def test_load_storage_keys_ignores_legacy_file_tenant_id(self, db_session_with_containers: Session):
|
||||
"""Legacy file tenant_id should not override the loader tenant scope."""
|
||||
upload_file = self._create_upload_file()
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(
|
||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
||||
tenant_id,
|
||||
related_id=upload_file.id,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
override_tenant_id=str(uuid4()),
|
||||
)
|
||||
|
||||
self.loader.load_storage_keys([file])
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
assert file._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_missing_file_id(self):
|
||||
"""Test with None file.related_id."""
|
||||
# Create a file with valid parameters first, then manually set related_id to None
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
def test_load_storage_keys_missing_file_id(self, db_session_with_containers: Session):
|
||||
"""Test with None file.related_id — should raise ValueError."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = None
|
||||
|
||||
# Should raise ValueError for None file related_id
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file])
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
with pytest.raises(ValueError, match="file id should not be None."):
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
assert str(context.value) == "file id should not be None."
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self, db_session_with_containers: Session):
|
||||
"""Test with missing UploadFile database records — should raise ValueError."""
|
||||
tenant_id = str(uuid4())
|
||||
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
||||
"""Test with missing UploadFile database records."""
|
||||
# Create file with non-existent upload file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
||||
"""Test with missing ToolFile database records."""
|
||||
# Create file with non-existent tool file id
|
||||
non_existent_id = str(uuid4())
|
||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
def test_load_storage_keys_nonexistent_tool_file_records(self, db_session_with_containers: Session):
|
||||
"""Test with missing ToolFile database records — should raise ValueError."""
|
||||
tenant_id = str(uuid4())
|
||||
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# Should raise ValueError for missing record
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_invalid_uuid(self):
|
||||
"""Test with invalid UUID format."""
|
||||
# Create a file with valid parameters first, then manually set invalid related_id
|
||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
def test_load_storage_keys_invalid_uuid(self, db_session_with_containers: Session):
|
||||
"""Test with invalid UUID format — should raise ValueError."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file.related_id = "invalid-uuid-format"
|
||||
|
||||
# Should raise ValueError for invalid UUID
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
self.loader.load_storage_keys([file])
|
||||
loader.load_storage_keys([file])
|
||||
|
||||
def test_load_storage_keys_batch_efficiency(self):
|
||||
"""Test batched operations use efficient queries."""
|
||||
# Create multiple files of different types
|
||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
||||
def test_load_storage_keys_batch_efficiency(self, db_session_with_containers: Session):
|
||||
"""Batched operations should issue exactly 2 queries for mixed file types."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
|
||||
files = []
|
||||
files.extend(
|
||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
||||
upload_files = [self._create_upload_file(db_session_with_containers, tenant_id, user_id) for _ in range(3)]
|
||||
tool_files = [
|
||||
self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) for _ in range(2)
|
||||
]
|
||||
|
||||
files = [
|
||||
self._create_file(tenant_id, related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
for uf in upload_files
|
||||
] + [
|
||||
self._create_file(tenant_id, related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||
for tf in tool_files
|
||||
]
|
||||
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
files.extend(
|
||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
||||
)
|
||||
|
||||
# Mock the session to count queries
|
||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
||||
self.loader.load_storage_keys(files)
|
||||
|
||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
||||
with patch.object(
|
||||
db_session_with_containers, "scalars", wraps=db_session_with_containers.scalars
|
||||
) as mock_scalars:
|
||||
loader.load_storage_keys(files)
|
||||
# Exactly 2 DB round-trips: one for UploadFile, one for ToolFile
|
||||
assert mock_scalars.call_count == 2
|
||||
|
||||
# Verify all storage keys were loaded correctly
|
||||
for i, file in enumerate(files[:3]):
|
||||
assert file._storage_key == upload_files[i].key
|
||||
for i, file in enumerate(files[3:]):
|
||||
assert file._storage_key == tool_files[i].file_key
|
||||
|
||||
def test_load_storage_keys_tenant_isolation(self):
|
||||
"""Test that tenant isolation works correctly."""
|
||||
# Create files for different tenants
|
||||
def test_load_storage_keys_tenant_isolation(self, db_session_with_containers: Session):
|
||||
"""Loader should not surface records belonging to a different tenant."""
|
||||
tenant_id = str(uuid4())
|
||||
other_tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Create upload file for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create upload file for other tenant (but don't add to cleanup list)
|
||||
upload_file_other = UploadFile(
|
||||
tenant_id=other_tenant_id,
|
||||
storage_type=StorageType.LOCAL,
|
||||
key="other_tenant_key",
|
||||
name="other_file.txt",
|
||||
size=1024,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=self.user_id,
|
||||
created_at=datetime.now(UTC),
|
||||
used=False,
|
||||
upload_file_other = self._create_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id,
|
||||
user_id,
|
||||
override_tenant_id=other_tenant_id,
|
||||
)
|
||||
upload_file_other.id = str(uuid4())
|
||||
self.session.add(upload_file_other)
|
||||
self.session.flush()
|
||||
|
||||
# Create file for other tenant but try to load with current tenant's loader
|
||||
file_other = self._create_file(
|
||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
tenant_id,
|
||||
related_id=upload_file_other.id,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
override_tenant_id=other_tenant_id,
|
||||
)
|
||||
|
||||
# Should raise ValueError due to tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_other])
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
|
||||
assert "Upload file not found for id:" in str(context.value)
|
||||
with pytest.raises(ValueError, match="Upload file not found for id:"):
|
||||
loader.load_storage_keys([file_other])
|
||||
|
||||
# Current tenant's file should still work
|
||||
self.loader.load_storage_keys([file_current])
|
||||
# Current-tenant file still resolves correctly
|
||||
loader.load_storage_keys([file_current])
|
||||
assert file_current._storage_key == upload_file_current.key
|
||||
|
||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
||||
# Create files for current tenant
|
||||
upload_file_current = self._create_upload_file()
|
||||
def test_load_storage_keys_mixed_tenant_batch(self, db_session_with_containers: Session):
|
||||
"""A batch containing a foreign-tenant file should fail on the mismatch."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file_current = self._create_file(
|
||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||
)
|
||||
|
||||
# Create file for different tenant
|
||||
other_tenant_id = str(uuid4())
|
||||
file_other = self._create_file(
|
||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
||||
tenant_id,
|
||||
related_id=str(uuid4()),
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
override_tenant_id=str(uuid4()),
|
||||
)
|
||||
|
||||
# Should raise ValueError on tenant mismatch
|
||||
with pytest.raises(ValueError) as context:
|
||||
self.loader.load_storage_keys([file_current, file_other])
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
with pytest.raises(ValueError, match="Upload file not found for id:"):
|
||||
loader.load_storage_keys([file_current, file_other])
|
||||
|
||||
assert "Upload file not found for id:" in str(context.value)
|
||||
def test_load_storage_keys_duplicate_file_ids(self, db_session_with_containers: Session):
|
||||
"""Duplicate file IDs in the batch should be handled gracefully."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
def test_load_storage_keys_duplicate_file_ids(self):
|
||||
"""Test handling of duplicate file IDs in the batch."""
|
||||
# Create upload file
|
||||
upload_file = self._create_upload_file()
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file1 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# Create two File objects with same related_id
|
||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
loader = StorageKeyLoader(
|
||||
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||
)
|
||||
loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Should handle duplicates gracefully
|
||||
self.loader.load_storage_keys([file1, file2])
|
||||
|
||||
# Both files should have the same storage key
|
||||
assert file1._storage_key == upload_file.key
|
||||
assert file2._storage_key == upload_file.key
|
||||
|
||||
def test_load_storage_keys_session_isolation(self):
|
||||
"""Test that the loader uses the provided session correctly."""
|
||||
# Create test data
|
||||
upload_file = self._create_upload_file()
|
||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
def test_load_storage_keys_session_isolation(self, db_session_with_containers: Session):
|
||||
"""A loader backed by an uncommitted session should not see data from another session."""
|
||||
tenant_id = str(uuid4())
|
||||
user_id = str(uuid4())
|
||||
|
||||
# Create loader with different session (same underlying connection)
|
||||
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||
|
||||
# A loader with a fresh, separate session cannot see uncommitted rows from db_session_with_containers
|
||||
with Session(bind=db.engine) as other_session:
|
||||
other_loader = StorageKeyLoader(
|
||||
other_session,
|
||||
self.tenant_id,
|
||||
tenant_id,
|
||||
access_controller=DatabaseFileAccessController(),
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@ -8,6 +8,7 @@ Covers real Redis 7+ sharded pub/sub interactions including:
|
||||
- Resource cleanup accounting via PUBSUB SHARDNUMSUB
|
||||
"""
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@ -356,10 +357,17 @@ class TestShardedRedisBroadcastChannelClusterIntegration:
|
||||
def _get_test_topic_name(cls) -> str:
|
||||
return f"test_sharded_cluster_topic_{uuid.uuid4()}"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_announced_ip(host: str) -> str:
|
||||
"""Resolve the container host name to a literal IP accepted by Redis cluster config."""
|
||||
return socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)[0][4][0]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_single_node_cluster(host: str, port: int) -> None:
|
||||
"""Bootstrap a single-node cluster using a literal IP for Redis node advertisement."""
|
||||
client = redis.Redis(host=host, port=port, decode_responses=False)
|
||||
client.config_set("cluster-announce-ip", host)
|
||||
announced_ip = TestShardedRedisBroadcastChannelClusterIntegration._resolve_announced_ip(host)
|
||||
client.config_set("cluster-announce-ip", announced_ip)
|
||||
client.config_set("cluster-announce-port", port)
|
||||
slots = client.execute_command("CLUSTER", "SLOTS")
|
||||
if not slots:
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -31,7 +32,7 @@ class TestApiKeyAuthService:
|
||||
def mock_args(self, category, provider, mock_credentials) -> dict:
|
||||
return {"category": category, "provider": provider, "credentials": mock_credentials}
|
||||
|
||||
def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False):
|
||||
def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False):
|
||||
binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id,
|
||||
category=category,
|
||||
@ -44,7 +45,7 @@ class TestApiKeyAuthService:
|
||||
return binding
|
||||
|
||||
def test_get_provider_auth_list_success(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
|
||||
db_session_with_containers.expire_all()
|
||||
@ -56,14 +57,16 @@ class TestApiKeyAuthService:
|
||||
assert len(tenant_results) == 1
|
||||
assert tenant_results[0].provider == provider
|
||||
|
||||
def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id):
|
||||
def test_get_provider_auth_list_empty(
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
assert tenant_results == []
|
||||
|
||||
def test_get_provider_auth_list_filters_disabled(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
self._create_binding(
|
||||
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True
|
||||
@ -78,7 +81,13 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_success(
|
||||
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
|
||||
self,
|
||||
mock_encrypter,
|
||||
mock_factory,
|
||||
flask_app_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
mock_args,
|
||||
):
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
@ -97,7 +106,7 @@ class TestApiKeyAuthService:
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_validation_failed(
|
||||
self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
|
||||
self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args
|
||||
):
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = False
|
||||
@ -112,7 +121,13 @@ class TestApiKeyAuthService:
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
def test_create_provider_auth_encrypts_api_key(
|
||||
self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args
|
||||
self,
|
||||
mock_encrypter,
|
||||
mock_factory,
|
||||
flask_app_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
mock_args,
|
||||
):
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
@ -128,7 +143,13 @@ class TestApiKeyAuthService:
|
||||
mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key)
|
||||
|
||||
def test_get_auth_credentials_success(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials
|
||||
self,
|
||||
flask_app_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
category,
|
||||
provider,
|
||||
mock_credentials,
|
||||
):
|
||||
self._create_binding(
|
||||
db_session_with_containers,
|
||||
@ -144,14 +165,14 @@ class TestApiKeyAuthService:
|
||||
assert result == mock_credentials
|
||||
|
||||
def test_get_auth_credentials_not_found(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_auth_credentials_json_parsing(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
|
||||
self._create_binding(
|
||||
@ -169,7 +190,7 @@ class TestApiKeyAuthService:
|
||||
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
|
||||
|
||||
def test_delete_provider_auth_success(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
binding = self._create_binding(
|
||||
db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider
|
||||
@ -183,7 +204,9 @@ class TestApiKeyAuthService:
|
||||
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
|
||||
assert remaining is None
|
||||
|
||||
def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id):
|
||||
def test_delete_provider_auth_not_found(
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
# Should not raise when binding not found
|
||||
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
@ -114,7 +115,7 @@ class TestAuthIntegration:
|
||||
assert result2[0].tenant_id == tenant_id_2
|
||||
|
||||
def test_cross_tenant_access_prevention(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category
|
||||
):
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.storage.storage_type import StorageType
|
||||
@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument:
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_pause_document_waiting_state_success(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pause of document in waiting state.
|
||||
|
||||
@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument:
|
||||
mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True")
|
||||
|
||||
def test_pause_document_indexing_state_success(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pause of document in indexing state.
|
||||
@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument:
|
||||
assert document.is_paused is True
|
||||
assert document.paused_by == mock_document_service_dependencies["user_id"]
|
||||
|
||||
def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_pause_document_parsing_state_success(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pause of document in parsing state.
|
||||
|
||||
@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument:
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.is_paused is True
|
||||
|
||||
def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_pause_document_completed_state_error(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when trying to pause completed document.
|
||||
|
||||
@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument:
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.is_paused is False
|
||||
|
||||
def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_pause_document_error_state_error(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when trying to pause document in error state.
|
||||
|
||||
@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument:
|
||||
"recover_task": mock_task,
|
||||
}
|
||||
|
||||
def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_recover_document_paused_success(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful recovery of paused document.
|
||||
|
||||
@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument:
|
||||
document.dataset_id, document.id
|
||||
)
|
||||
|
||||
def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_recover_document_not_paused_error(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when trying to recover non-paused document.
|
||||
|
||||
@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument:
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_retry_document_single_success(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retry of single document.
|
||||
|
||||
@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument:
|
||||
dataset.id, [document.id], mock_document_service_dependencies["user_id"]
|
||||
)
|
||||
|
||||
def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_retry_document_multiple_success(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retry of multiple documents.
|
||||
|
||||
@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument:
|
||||
)
|
||||
|
||||
def test_retry_document_concurrent_retry_error(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when document is already being retried.
|
||||
@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument:
|
||||
assert document.indexing_status == IndexingStatus.ERROR
|
||||
|
||||
def test_retry_document_missing_current_user_error(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when current_user is missing.
|
||||
@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
}
|
||||
|
||||
def test_batch_update_document_status_enable_success(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch enabling of documents.
|
||||
@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
assert mock_document_service_dependencies["add_task"].delay.call_count == 2
|
||||
|
||||
def test_batch_update_document_status_disable_success(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch disabling of documents.
|
||||
@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
|
||||
|
||||
def test_batch_update_document_status_archive_success(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch archiving of documents.
|
||||
@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id)
|
||||
|
||||
def test_batch_update_document_status_unarchive_success(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch unarchiving of documents.
|
||||
@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id)
|
||||
|
||||
def test_batch_update_document_status_empty_list(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of empty document list.
|
||||
@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus:
|
||||
mock_document_service_dependencies["remove_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_document_status_document_indexing_error(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when document is being indexed.
|
||||
@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument:
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies):
|
||||
"""
|
||||
Test successful document renaming.
|
||||
|
||||
@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument:
|
||||
assert result == document
|
||||
assert document.name == new_name
|
||||
|
||||
def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_rename_document_with_built_in_fields(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test document renaming with built-in fields enabled.
|
||||
|
||||
@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument:
|
||||
assert document.doc_metadata["document_name"] == new_name
|
||||
assert document.doc_metadata["existing_key"] == "existing_value"
|
||||
|
||||
def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_rename_document_with_upload_file(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test document renaming with associated upload file.
|
||||
|
||||
@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument:
|
||||
assert upload_file.name == new_name
|
||||
|
||||
def test_rename_document_dataset_not_found_error(
|
||||
self, db_session_with_containers, mock_document_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when dataset is not found.
|
||||
@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument:
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_rename_document_not_found_error(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when document is not found.
|
||||
|
||||
@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument:
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
DocumentService.rename_document(dataset.id, document_id, new_name)
|
||||
|
||||
def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies):
|
||||
def test_rename_document_permission_error(
|
||||
self, db_session_with_containers: Session, mock_document_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when user lacks permission.
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from redis import RedisError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import TenantAccountJoin
|
||||
@ -122,7 +123,7 @@ class TestSyncAccountDeletion:
|
||||
mock_queue_task.assert_not_called()
|
||||
|
||||
def test_sync_account_deletion_multiple_workspaces(
|
||||
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
|
||||
):
|
||||
account_id = str(uuid4())
|
||||
tenant_ids = [str(uuid4()) for _ in range(3)]
|
||||
@ -144,7 +145,7 @@ class TestSyncAccountDeletion:
|
||||
assert queued_workspace_ids == set(tenant_ids)
|
||||
|
||||
def test_sync_account_deletion_no_workspaces(
|
||||
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
|
||||
):
|
||||
with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
@ -155,7 +156,7 @@ class TestSyncAccountDeletion:
|
||||
mock_queue_task.assert_not_called()
|
||||
|
||||
def test_sync_account_deletion_partial_failure(
|
||||
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
|
||||
):
|
||||
account_id = str(uuid4())
|
||||
tenant_ids = [str(uuid4()) for _ in range(3)]
|
||||
@ -180,7 +181,7 @@ class TestSyncAccountDeletion:
|
||||
assert mock_queue_task.call_count == 3
|
||||
|
||||
def test_sync_account_deletion_all_failures(
|
||||
self, flask_app_with_containers, db_session_with_containers, mock_queue_task
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task
|
||||
):
|
||||
account_id = str(uuid4())
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
|
||||
def _tenant_id() -> str:
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None:
|
||||
session.expire_all()
|
||||
stmt = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
|
||||
return session.scalars(stmt).one_or_none()
|
||||
|
||||
|
||||
def _count_permissions(session: Session, tenant_id: str) -> int:
|
||||
stmt = select(func.count()).select_from(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
|
||||
return session.scalar(stmt) or 0
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
"""Integration tests for PluginPermissionService.get_permission using testcontainers."""
|
||||
|
||||
def test_returns_permission_when_found(self, db_session_with_containers: Session):
|
||||
tenant_id = _tenant_id()
|
||||
permission = TenantPluginPermission(
|
||||
tenant_id=tenant_id,
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
db_session_with_containers.add(permission)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = PluginPermissionService.get_permission(tenant_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == permission.id
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_returns_none_when_not_found(self, db_session_with_containers: Session):
|
||||
result = PluginPermissionService.get_permission(_tenant_id())
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangePermission:
|
||||
"""Integration tests for PluginPermissionService.change_permission using testcontainers."""
|
||||
|
||||
def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session):
|
||||
tenant_id = _tenant_id()
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
tenant_id,
|
||||
TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
|
||||
permission = _get_permission(db_session_with_containers, tenant_id)
|
||||
assert result is True
|
||||
assert permission is not None
|
||||
assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
|
||||
assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_updates_existing_permission(self, db_session_with_containers: Session):
|
||||
tenant_id = _tenant_id()
|
||||
existing = TenantPluginPermission(
|
||||
tenant_id=tenant_id,
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
db_session_with_containers.add(existing)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
tenant_id,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
|
||||
permission = _get_permission(db_session_with_containers, tenant_id)
|
||||
assert result is True
|
||||
assert permission is not None
|
||||
assert permission.id == existing.id
|
||||
assert permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
assert _count_permissions(db_session_with_containers, tenant_id) == 1
|
||||
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import App, RecommendedApp, Site
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval:
|
||||
|
||||
|
||||
class TestFetchRecommendedAppsFromDb:
|
||||
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
|
||||
_create_site(db_session_with_containers, app_id=app1.id)
|
||||
@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb:
|
||||
assert "assistant" in result["categories"]
|
||||
assert "writing" in result["categories"]
|
||||
|
||||
def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_falls_back_to_default_language_when_empty(
|
||||
self, flask_app_with_containers, db_session_with_containers: Session
|
||||
):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
|
||||
_create_site(db_session_with_containers, app_id=app1.id)
|
||||
@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb:
|
||||
app_ids = {r["app_id"] for r in result["recommended_apps"]}
|
||||
assert app1.id in app_ids
|
||||
|
||||
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
|
||||
_create_site(db_session_with_containers, app_id=app1.id)
|
||||
@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb:
|
||||
app_ids = {r["app_id"] for r in result["recommended_apps"]}
|
||||
assert app1.id not in app_ids
|
||||
|
||||
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
|
||||
_create_recommended_app(db_session_with_containers, app_id=app1.id)
|
||||
@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb:
|
||||
|
||||
|
||||
class TestFetchRecommendedAppDetailFromDb:
|
||||
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False)
|
||||
_create_recommended_app(db_session_with_containers, app_id=app1.id)
|
||||
@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb:
|
||||
assert result is None
|
||||
|
||||
@patch("services.recommend_app.database.database_retrieval.AppDslService")
|
||||
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers):
|
||||
def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
app1 = _create_app(db_session_with_containers, tenant_id=tenant_id)
|
||||
_create_site(db_session_with_containers, app_id=app1.id)
|
||||
|
||||
@ -2,6 +2,7 @@ import copy
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService:
|
||||
# for consistency with other test files
|
||||
return {}
|
||||
|
||||
def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_prompt_baichuan_model_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful prompt generation for Baichuan model.
|
||||
|
||||
@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_prompt_common_model_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful prompt generation for common models.
|
||||
|
||||
@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_prompt_case_insensitive_baichuan_detection(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan model detection is case insensitive.
|
||||
@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for chat app with completion mode.
|
||||
@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#histories#}}" in prompt_text
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_common_prompt_chat_app_chat_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for chat app with chat mode.
|
||||
|
||||
@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with completion mode.
|
||||
@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation for completion app with chat mode.
|
||||
@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_common_prompt_no_context(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation without context.
|
||||
|
||||
@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_common_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported app mode.
|
||||
@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result == {}
|
||||
|
||||
def test_get_common_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test common prompt generation with unsupported model mode.
|
||||
@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService:
|
||||
# Assert: Verify empty dict is returned
|
||||
assert result == {}
|
||||
|
||||
def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_completion_prompt_with_context(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test completion prompt generation with context.
|
||||
|
||||
@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_completion_prompt_without_context(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test completion prompt generation without context.
|
||||
@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result_text == original_text
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_chat_prompt_with_context(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test chat prompt generation with context.
|
||||
|
||||
@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert original_text in result_text
|
||||
assert result_text == CONTEXT + original_text
|
||||
|
||||
def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_chat_prompt_without_context(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test chat prompt generation without context.
|
||||
|
||||
@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert CONTEXT not in result_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with completion mode.
|
||||
@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_chat_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for chat app with chat mode.
|
||||
@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_completion_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with completion mode.
|
||||
@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_completion_app_chat_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation for completion app with chat mode.
|
||||
@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert BAICHUAN_CONTEXT in prompt_text
|
||||
assert "{{#pre_prompt#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_baichuan_prompt_no_context(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation without context.
|
||||
|
||||
@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert "{{#query#}}" in prompt_text
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_app_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported app mode.
|
||||
@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result == {}
|
||||
|
||||
def test_get_baichuan_prompt_unsupported_model_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test Baichuan prompt generation with unsupported model mode.
|
||||
@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result == {}
|
||||
|
||||
def test_get_prompt_all_app_modes_common_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with common model.
|
||||
@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result != {}
|
||||
|
||||
def test_get_prompt_all_app_modes_baichuan_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prompt generation for all app modes with Baichuan model.
|
||||
@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert result is not None
|
||||
assert result != {}
|
||||
|
||||
def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test prompt generation with edge cases.
|
||||
|
||||
@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService:
|
||||
# Should either return a valid result or empty dict, but not crash
|
||||
assert result is not None
|
||||
|
||||
def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that original templates are not modified.
|
||||
|
||||
@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_baichuan_template_immutability(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that original Baichuan templates are not modified.
|
||||
|
||||
@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService:
|
||||
assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_context_integration_consistency(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test consistency of context integration across different scenarios.
|
||||
|
||||
@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService:
|
||||
assert prompt_text.startswith(CONTEXT)
|
||||
|
||||
def test_baichuan_context_integration_consistency(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test consistency of Baichuan context integration across different scenarios.
|
||||
|
||||
@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
import base64
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from faker import Faker
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.trigger.constants import (
|
||||
TRIGGER_PLUGIN_NODE_TYPE,
|
||||
@ -17,7 +20,7 @@ from core.trigger.constants import (
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models import Account, AppMode
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, IconType
|
||||
from services import app_dsl_service
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -67,11 +70,27 @@ def _pending_yaml_content(version: str = "99.0.0") -> bytes:
|
||||
return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode()
|
||||
|
||||
|
||||
def _app_stub(**overrides: Any) -> App:
|
||||
defaults = {
|
||||
"id": str(uuid4()),
|
||||
"tenant_id": _DEFAULT_TENANT_ID,
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"name": "n",
|
||||
"description": "d",
|
||||
"icon_type": IconType.EMOJI,
|
||||
"icon": "i",
|
||||
"icon_background": "#fff",
|
||||
"use_icon_as_answer_icon": False,
|
||||
"app_model_config": None,
|
||||
}
|
||||
return cast(App, SimpleNamespace(**(defaults | overrides)))
|
||||
|
||||
|
||||
class TestAppDslService:
|
||||
"""Integration tests for AppDslService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
return flask_app_with_containers
|
||||
|
||||
@pytest.fixture
|
||||
@ -112,7 +131,7 @@ class TestAppDslService:
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
fake = Faker()
|
||||
with patch("services.account_service.FeatureService") as mock_account_feature_service:
|
||||
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
@ -189,7 +208,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Import: Validation ────────────────────────────────────────────
|
||||
|
||||
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers):
|
||||
def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Invalid import_mode"):
|
||||
service.import_app(
|
||||
@ -198,7 +217,7 @@ class TestAppDslService:
|
||||
yaml_content="version: '0.1.0'",
|
||||
)
|
||||
|
||||
def test_import_app_missing_yaml_content(self, db_session_with_containers):
|
||||
def test_import_app_missing_yaml_content(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -208,7 +227,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "yaml_content is required" in result.error
|
||||
|
||||
def test_import_app_missing_yaml_url(self, db_session_with_containers):
|
||||
def test_import_app_missing_yaml_url(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -218,7 +237,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "yaml_url is required" in result.error
|
||||
|
||||
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers):
|
||||
def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -228,7 +247,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "content must be a mapping" in result.error
|
||||
|
||||
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers):
|
||||
def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}})
|
||||
result = service.import_app(
|
||||
@ -239,7 +258,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Invalid version type" in result.error
|
||||
|
||||
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers):
|
||||
def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -249,7 +268,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Missing app data" in result.error
|
||||
|
||||
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
def bad_safe_load(_content: str):
|
||||
raise yaml.YAMLError("bad")
|
||||
|
||||
@ -264,7 +283,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert result.error.startswith("Invalid YAML format:")
|
||||
|
||||
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
"_create_or_update_app",
|
||||
@ -282,7 +301,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Import: YAML URL ──────────────────────────────────────────────
|
||||
|
||||
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
app_dsl_service.ssrf_proxy,
|
||||
"get",
|
||||
@ -298,7 +317,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Error fetching YAML from URL: boom" in result.error
|
||||
|
||||
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
response = MagicMock()
|
||||
response.content = b""
|
||||
response.raise_for_status.return_value = None
|
||||
@ -313,7 +332,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "Empty content" in result.error
|
||||
|
||||
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch):
|
||||
response = MagicMock()
|
||||
response.content = b"x" * (DSL_MAX_SIZE + 1)
|
||||
response.raise_for_status.return_value = None
|
||||
@ -328,7 +347,9 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "File size exceeds" in result.error
|
||||
|
||||
def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_url_user_attachments_keeps_original_url(
|
||||
self, db_session_with_containers: Session, monkeypatch
|
||||
):
|
||||
yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
|
||||
@ -354,7 +375,7 @@ class TestAppDslService:
|
||||
assert result.imported_dsl_version == "99.0.0"
|
||||
assert requested_urls == [yaml_url]
|
||||
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch):
|
||||
def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch):
|
||||
yaml_url = "https://github.com/acme/repo/blob/main/app.yml"
|
||||
raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml"
|
||||
yaml_bytes = _pending_yaml_content()
|
||||
@ -383,7 +404,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Import: App ID checks ────────────────────────────────────────
|
||||
|
||||
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers):
|
||||
def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -395,7 +416,7 @@ class TestAppDslService:
|
||||
assert result.error == "App not found"
|
||||
|
||||
def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
assert app.mode == "chat"
|
||||
@ -412,7 +433,7 @@ class TestAppDslService:
|
||||
|
||||
# ── Import: Flow ──────────────────────────────────────────────────
|
||||
|
||||
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers):
|
||||
def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.import_app(
|
||||
account=_account_mock(),
|
||||
@ -432,7 +453,7 @@ class TestAppDslService:
|
||||
assert stored is not None
|
||||
|
||||
def test_import_app_completed_uses_declared_dependencies(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
@ -466,7 +487,7 @@ class TestAppDslService:
|
||||
|
||||
@pytest.mark.parametrize("has_workflow", [True, False])
|
||||
def test_import_app_legacy_versions_extract_dependencies(
|
||||
self, db_session_with_containers, monkeypatch, has_workflow: bool
|
||||
self, db_session_with_containers: Session, monkeypatch, has_workflow: bool
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
AppDslService,
|
||||
@ -523,13 +544,13 @@ class TestAppDslService:
|
||||
|
||||
# ── Confirm Import ────────────────────────────────────────────────
|
||||
|
||||
def test_confirm_import_expired_returns_failed(self, db_session_with_containers):
|
||||
def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.confirm_import(import_id=str(uuid4()), account=_account_mock())
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "expired" in result.error
|
||||
|
||||
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch):
|
||||
def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch):
|
||||
import_id = str(uuid4())
|
||||
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
|
||||
|
||||
@ -562,7 +583,7 @@ class TestAppDslService:
|
||||
assert result.app_id == created_app.id
|
||||
assert redis_client.get(redis_key) is None
|
||||
|
||||
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers):
|
||||
def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session):
|
||||
import_id = str(uuid4())
|
||||
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
|
||||
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123")
|
||||
@ -572,7 +593,7 @@ class TestAppDslService:
|
||||
assert result.status == ImportStatus.FAILED
|
||||
assert "validation error" in result.error
|
||||
|
||||
def test_confirm_import_exception_returns_failed(self, db_session_with_containers):
|
||||
def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session):
|
||||
import_id = str(uuid4())
|
||||
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
|
||||
redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json")
|
||||
@ -583,13 +604,13 @@ class TestAppDslService:
|
||||
|
||||
# ── Check Dependencies ────────────────────────────────────────────
|
||||
|
||||
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers):
|
||||
def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
app_model = SimpleNamespace(id=str(uuid4()), tenant_id=_DEFAULT_TENANT_ID)
|
||||
app_model = _app_stub()
|
||||
result = service.check_dependencies(app_model=app_model)
|
||||
assert result.leaked_dependencies == []
|
||||
|
||||
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch):
|
||||
def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch):
|
||||
app_id = str(uuid4())
|
||||
pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id)
|
||||
redis_client.setex(
|
||||
@ -614,10 +635,12 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
service = AppDslService(db_session_with_containers)
|
||||
result = service.check_dependencies(app_model=SimpleNamespace(id=app_id, tenant_id=_DEFAULT_TENANT_ID))
|
||||
result = service.check_dependencies(app_model=_app_stub(id=app_id))
|
||||
assert len(result.leaked_dependencies) == 1
|
||||
|
||||
def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_check_dependencies_with_real_app(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}'
|
||||
@ -633,12 +656,12 @@ class TestAppDslService:
|
||||
|
||||
# ── Create/Update App ─────────────────────────────────────────────
|
||||
|
||||
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers):
|
||||
def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="loss app mode"):
|
||||
service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock())
|
||||
|
||||
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch):
|
||||
def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch):
|
||||
fixed_now = object()
|
||||
monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now)
|
||||
|
||||
@ -656,9 +679,7 @@ class TestAppDslService:
|
||||
lambda _m: SimpleNamespace(kind="conv"),
|
||||
)
|
||||
|
||||
app = SimpleNamespace(
|
||||
id=str(uuid4()),
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
name="old",
|
||||
description="old-desc",
|
||||
@ -667,7 +688,6 @@ class TestAppDslService:
|
||||
icon_background="#111111",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
service = AppDslService(db_session_with_containers)
|
||||
updated = service._create_or_update_app(
|
||||
@ -693,7 +713,7 @@ class TestAppDslService:
|
||||
assert app.icon_background == "#222222"
|
||||
assert app.updated_at is fixed_now
|
||||
|
||||
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers):
|
||||
def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session):
|
||||
account = _account_mock()
|
||||
account.current_tenant_id = None
|
||||
service = AppDslService(db_session_with_containers)
|
||||
@ -705,7 +725,7 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
_, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
@ -741,42 +761,26 @@ class TestAppDslService:
|
||||
stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}")
|
||||
assert stored is not None
|
||||
|
||||
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers):
|
||||
def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Missing workflow data"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id=str(uuid4()),
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
app=_app_stub(mode=AppMode.WORKFLOW.value),
|
||||
data={"app": {"mode": AppMode.WORKFLOW.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers):
|
||||
def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Missing model_config"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id=str(uuid4()),
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
mode=AppMode.CHAT.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
app=_app_stub(mode=AppMode.CHAT.value),
|
||||
data={"app": {"mode": AppMode.CHAT.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
def test_create_or_update_app_chat_creates_model_config_and_sends_event(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
app.app_model_config_id = None
|
||||
@ -795,19 +799,11 @@ class TestAppDslService:
|
||||
db_session_with_containers.expire_all()
|
||||
assert app.app_model_config_id is not None
|
||||
|
||||
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers):
|
||||
def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session):
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Invalid app mode"):
|
||||
service._create_or_update_app(
|
||||
app=SimpleNamespace(
|
||||
id=str(uuid4()),
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
mode=AppMode.RAG_PIPELINE.value,
|
||||
name="n",
|
||||
description="d",
|
||||
icon_background="#fff",
|
||||
app_model_config=None,
|
||||
),
|
||||
app=_app_stub(mode=AppMode.RAG_PIPELINE.value),
|
||||
data={"app": {"mode": AppMode.RAG_PIPELINE.value}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
@ -828,29 +824,16 @@ class TestAppDslService:
|
||||
lambda *_args, **_kwargs: model_calls.append(True),
|
||||
)
|
||||
|
||||
workflow_app = SimpleNamespace(
|
||||
workflow_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
name="n",
|
||||
icon="i",
|
||||
icon_type="emoji",
|
||||
icon_background="#fff",
|
||||
description="d",
|
||||
use_icon_as_answer_icon=False,
|
||||
app_model_config=None,
|
||||
)
|
||||
AppDslService.export_dsl(workflow_app)
|
||||
assert workflow_calls == [True]
|
||||
|
||||
chat_app = SimpleNamespace(
|
||||
chat_app = _app_stub(
|
||||
mode=AppMode.CHAT.value,
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
name="n",
|
||||
icon="i",
|
||||
icon_type="emoji",
|
||||
icon_background="#fff",
|
||||
description="d",
|
||||
use_icon_as_answer_icon=False,
|
||||
app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}),
|
||||
)
|
||||
AppDslService.export_dsl(chat_app)
|
||||
@ -863,16 +846,14 @@ class TestAppDslService:
|
||||
lambda **_kwargs: None,
|
||||
)
|
||||
|
||||
emoji_app = SimpleNamespace(
|
||||
emoji_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
name="Emoji App",
|
||||
icon="🎨",
|
||||
icon_type=IconType.EMOJI,
|
||||
icon_background="#FF5733",
|
||||
description="App with emoji icon",
|
||||
use_icon_as_answer_icon=True,
|
||||
app_model_config=None,
|
||||
)
|
||||
yaml_output = AppDslService.export_dsl(emoji_app)
|
||||
data = yaml.safe_load(yaml_output)
|
||||
@ -880,16 +861,14 @@ class TestAppDslService:
|
||||
assert data["app"]["icon_type"] == "emoji"
|
||||
assert data["app"]["icon_background"] == "#FF5733"
|
||||
|
||||
image_app = SimpleNamespace(
|
||||
image_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
name="Image App",
|
||||
icon="https://example.com/icon.png",
|
||||
icon_type=IconType.IMAGE,
|
||||
icon_background="#FFEAD5",
|
||||
description="App with image icon",
|
||||
use_icon_as_answer_icon=False,
|
||||
app_model_config=None,
|
||||
)
|
||||
yaml_output = AppDslService.export_dsl(image_app)
|
||||
data = yaml.safe_load(yaml_output)
|
||||
@ -897,7 +876,7 @@ class TestAppDslService:
|
||||
assert data["app"]["icon_type"] == "image"
|
||||
assert data["app"]["icon_background"] == "#FFEAD5"
|
||||
|
||||
def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
model_config = AppModelConfig(
|
||||
@ -935,7 +914,9 @@ class TestAppDslService:
|
||||
assert "model_config" in exported_data
|
||||
assert "dependencies" in exported_data
|
||||
|
||||
def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_export_dsl_workflow_app_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
app.mode = "workflow"
|
||||
db_session_with_containers.commit()
|
||||
@ -968,7 +949,9 @@ class TestAppDslService:
|
||||
assert "workflow" in exported_data
|
||||
assert "dependencies" in exported_data
|
||||
|
||||
def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_export_dsl_with_workflow_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
app.mode = "workflow"
|
||||
db_session_with_containers.commit()
|
||||
@ -1008,7 +991,7 @@ class TestAppDslService:
|
||||
assert "workflow" in exported_data
|
||||
|
||||
def test_export_dsl_with_invalid_workflow_id_raises_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
app.mode = "workflow"
|
||||
@ -1106,7 +1089,7 @@ class TestAppDslService:
|
||||
export_data: dict = {}
|
||||
AppDslService._append_workflow_export_data(
|
||||
export_data=export_data,
|
||||
app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID),
|
||||
app_model=_app_stub(),
|
||||
include_secret=False,
|
||||
workflow_id=None,
|
||||
)
|
||||
@ -1132,7 +1115,7 @@ class TestAppDslService:
|
||||
with pytest.raises(ValueError, match="Missing draft workflow configuration"):
|
||||
AppDslService._append_workflow_export_data(
|
||||
export_data={},
|
||||
app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID),
|
||||
app_model=_app_stub(),
|
||||
include_secret=False,
|
||||
workflow_id=None,
|
||||
)
|
||||
@ -1160,7 +1143,7 @@ class TestAppDslService:
|
||||
monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x)
|
||||
|
||||
app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}})
|
||||
app_model = SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID, app_model_config=app_model_config)
|
||||
app_model = _app_stub(app_model_config=app_model_config)
|
||||
export_data: dict = {}
|
||||
|
||||
AppDslService._append_model_config_export_data(export_data, app_model)
|
||||
@ -1169,7 +1152,7 @@ class TestAppDslService:
|
||||
|
||||
def test_append_model_config_export_data_requires_app_config(self):
|
||||
with pytest.raises(ValueError, match="Missing app configuration"):
|
||||
AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None))
|
||||
AppDslService._append_model_config_export_data({}, _app_stub(app_model_config=None))
|
||||
|
||||
# ── Dependency Extraction ─────────────────────────────────────────
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import App
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -184,7 +185,7 @@ class TestAppGenerateService:
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers: Session, app):
|
||||
def _create_test_workflow(self, db_session_with_containers: Session, app: App):
|
||||
"""
|
||||
Helper method to create a test workflow for testing.
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
|
||||
def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile:
|
||||
def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile:
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id or str(uuid4()),
|
||||
storage_type=StorageType.OPENDAL,
|
||||
@ -60,7 +60,7 @@ class TestAttachmentService:
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
AttachmentService(session_factory=invalid_session_factory)
|
||||
|
||||
def test_should_return_base64_when_file_exists(self, db_session_with_containers):
|
||||
def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session):
|
||||
upload_file = self._create_upload_file(db_session_with_containers)
|
||||
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
|
||||
|
||||
@ -70,7 +70,7 @@ class TestAttachmentService:
|
||||
assert result == base64.b64encode(b"binary-content").decode()
|
||||
mock_load.assert_called_once_with(upload_file.key)
|
||||
|
||||
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers):
|
||||
def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session):
|
||||
service = AttachmentService(session_factory=sessionmaker(bind=db.engine))
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_redis_cleanup(self, flask_app_with_containers):
|
||||
def setup_redis_cleanup(self, flask_app_with_containers: Flask):
|
||||
"""Clean up Redis cache before and after each test."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Clean up before test
|
||||
@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
return value
|
||||
return None
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask):
|
||||
"""Test bulk plan retrieval when all tenants are in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
# Verify API was not called
|
||||
mock_get_plan_bulk.assert_not_called()
|
||||
|
||||
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask):
|
||||
"""Test bulk plan retrieval when all tenants are not in cache."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
assert ttl_1 > 0
|
||||
assert ttl_1 <= 600 # Should be <= 600 seconds
|
||||
|
||||
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask):
|
||||
"""Test bulk plan retrieval when some tenants are in cache, some are not."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
cached_data_3 = json.loads(cached_3)
|
||||
assert cached_data_3 == missing_plan["tenant-3"]
|
||||
|
||||
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask):
|
||||
"""Test fallback to API when Redis mget fails."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
assert cached_1 is not None
|
||||
assert cached_2 is not None
|
||||
|
||||
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask):
|
||||
"""Test fallback to API when cache contains invalid JSON."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
cached_data_3 = json.loads(cached_3)
|
||||
assert cached_data_3 == expected_plans["tenant-3"]
|
||||
|
||||
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask):
|
||||
"""Test fallback to API when cache data doesn't match SubscriptionPlan schema."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
# Verify API was called for tenant-2 and tenant-3
|
||||
mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"])
|
||||
|
||||
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask):
|
||||
"""Test that pipeline failure doesn't affect return value."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
# Verify pipeline was attempted
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask):
|
||||
"""Test with empty tenant_ids list."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Act
|
||||
@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache:
|
||||
# But we should check that mget was not called at all
|
||||
# Since we can't easily verify this without more mocking, we just verify the result
|
||||
|
||||
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers):
|
||||
def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask):
|
||||
"""Test that expired cache keys are treated as cache misses."""
|
||||
with flask_app_with_containers.app_context():
|
||||
# Arrange
|
||||
|
||||
@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory:
|
||||
class TestConversationServicePagination:
|
||||
"""Test conversation pagination operations."""
|
||||
|
||||
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers):
|
||||
def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that non-empty include_ids filters properly.
|
||||
|
||||
@ -204,7 +205,7 @@ class TestConversationServicePagination:
|
||||
returned_ids = {conversation.id for conversation in result.data}
|
||||
assert returned_ids == {conversations[0].id, conversations[1].id}
|
||||
|
||||
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers):
|
||||
def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that empty exclude_ids doesn't filter.
|
||||
|
||||
@ -237,7 +238,7 @@ class TestConversationServicePagination:
|
||||
# Assert
|
||||
assert len(result.data) == len(conversations)
|
||||
|
||||
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers):
|
||||
def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that non-empty exclude_ids filters properly.
|
||||
|
||||
@ -271,7 +272,7 @@ class TestConversationServicePagination:
|
||||
returned_ids = {conversation.id for conversation in result.data}
|
||||
assert returned_ids == {conversations[2].id}
|
||||
|
||||
def test_pagination_with_sorting_descending(self, db_session_with_containers):
|
||||
def test_pagination_with_sorting_descending(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test pagination with descending sort order.
|
||||
|
||||
@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation:
|
||||
within conversations.
|
||||
"""
|
||||
|
||||
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers):
|
||||
def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test message pagination without specifying first_id.
|
||||
|
||||
@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == 3 # All 3 messages returned
|
||||
assert result.has_more is False # No more messages available (3 < limit of 10)
|
||||
|
||||
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers):
|
||||
def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test message pagination with first_id specified.
|
||||
|
||||
@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == 2 # Only 2 messages returned after first_id
|
||||
assert result.has_more is False # No more messages available (2 < limit of 10)
|
||||
|
||||
def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers):
|
||||
def test_pagination_by_first_id_raises_error_when_first_message_not_found(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test that FirstMessageNotExistsError is raised when first_id doesn't exist.
|
||||
|
||||
@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation:
|
||||
limit=10,
|
||||
)
|
||||
|
||||
def test_pagination_with_has_more_flag(self, db_session_with_containers):
|
||||
def test_pagination_with_has_more_flag(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that has_more flag is correctly set when there are more messages.
|
||||
|
||||
@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == limit # Extra message should be removed
|
||||
assert result.has_more is True # Flag should be set
|
||||
|
||||
def test_pagination_with_ascending_order(self, db_session_with_containers):
|
||||
def test_pagination_with_ascending_order(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test message pagination with ascending order.
|
||||
|
||||
@ -512,7 +515,7 @@ class TestConversationServiceSummarization:
|
||||
"""
|
||||
|
||||
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
|
||||
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers):
|
||||
def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful auto-generation of conversation name.
|
||||
|
||||
@ -552,7 +555,7 @@ class TestConversationServiceSummarization:
|
||||
app_model.tenant_id, first_message.query, conversation.id, app_model.id
|
||||
)
|
||||
|
||||
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers):
|
||||
def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that MessageNotExistsError is raised when conversation has no messages.
|
||||
|
||||
@ -571,7 +574,9 @@ class TestConversationServiceSummarization:
|
||||
ConversationService.auto_generate_name(app_model, conversation)
|
||||
|
||||
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
|
||||
def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers):
|
||||
def test_auto_generate_name_handles_llm_failure_gracefully(
|
||||
self, mock_llm_generator, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test that LLM generation failures are suppressed and don't crash.
|
||||
|
||||
@ -604,7 +609,7 @@ class TestConversationServiceSummarization:
|
||||
assert conversation.name == original_name # Name remains unchanged
|
||||
|
||||
@patch("services.conversation_service.naive_utc_now")
|
||||
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers):
|
||||
def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session):
|
||||
"""
|
||||
Test renaming conversation with manual name.
|
||||
|
||||
@ -638,7 +643,7 @@ class TestConversationServiceSummarization:
|
||||
assert conversation.updated_at == mock_time
|
||||
|
||||
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
|
||||
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
|
||||
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session):
|
||||
"""
|
||||
Test rename delegates to auto_generate_name when auto_generate is True.
|
||||
|
||||
@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation:
|
||||
|
||||
@patch("services.annotation_service.add_annotation_to_index_task")
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers):
|
||||
def test_create_annotation_from_message(
|
||||
self, mock_current_account, mock_add_task, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test creating annotation from existing message.
|
||||
|
||||
@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation:
|
||||
|
||||
@patch("services.annotation_service.add_annotation_to_index_task")
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers):
|
||||
def test_create_annotation_without_message(
|
||||
self, mock_current_account, mock_add_task, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test creating standalone annotation without message.
|
||||
|
||||
@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation:
|
||||
|
||||
@patch("services.annotation_service.add_annotation_to_index_task")
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers):
|
||||
def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
|
||||
"""
|
||||
Test updating an existing annotation.
|
||||
|
||||
@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation:
|
||||
mock_add_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_get_annotation_list(self, mock_current_account, db_session_with_containers):
|
||||
def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving paginated annotation list.
|
||||
|
||||
@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation:
|
||||
assert result_total == 5
|
||||
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers):
|
||||
def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving annotations with keyword filtering.
|
||||
|
||||
@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation:
|
||||
|
||||
@patch("services.annotation_service.add_annotation_to_index_task")
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers):
|
||||
def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session):
|
||||
"""
|
||||
Test direct annotation insertion without message reference.
|
||||
|
||||
@ -919,7 +928,7 @@ class TestConversationServiceExport:
|
||||
Tests retrieving conversation data for export purposes.
|
||||
"""
|
||||
|
||||
def test_get_conversation_success(self, db_session_with_containers):
|
||||
def test_get_conversation_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of conversation."""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
@ -937,7 +946,7 @@ class TestConversationServiceExport:
|
||||
# Assert
|
||||
assert result == conversation
|
||||
|
||||
def test_get_conversation_not_found(self, db_session_with_containers):
|
||||
def test_get_conversation_not_found(self, db_session_with_containers: Session):
|
||||
"""Test ConversationNotExistsError when conversation doesn't exist."""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
@ -949,7 +958,7 @@ class TestConversationServiceExport:
|
||||
ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user)
|
||||
|
||||
@patch("services.annotation_service.current_account_with_tenant")
|
||||
def test_export_annotation_list(self, mock_current_account, db_session_with_containers):
|
||||
def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session):
|
||||
"""Test exporting all annotations for an app."""
|
||||
# Arrange
|
||||
app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
@ -977,7 +986,7 @@ class TestConversationServiceExport:
|
||||
# Assert
|
||||
assert len(result) == 10
|
||||
|
||||
def test_get_message_success(self, db_session_with_containers):
|
||||
def test_get_message_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of a message."""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
@ -1001,7 +1010,7 @@ class TestConversationServiceExport:
|
||||
# Assert
|
||||
assert result == message
|
||||
|
||||
def test_get_message_not_found(self, db_session_with_containers):
|
||||
def test_get_message_not_found(self, db_session_with_containers: Session):
|
||||
"""Test MessageNotExistsError when message doesn't exist."""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
@ -1012,7 +1021,7 @@ class TestConversationServiceExport:
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4()))
|
||||
|
||||
def test_get_conversation_for_end_user(self, db_session_with_containers):
|
||||
def test_get_conversation_for_end_user(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving conversation created by end user via API.
|
||||
|
||||
@ -1038,7 +1047,7 @@ class TestConversationServiceExport:
|
||||
assert result == conversation
|
||||
|
||||
@patch("services.conversation_service.delete_conversation_related_data")
|
||||
def test_delete_conversation(self, mock_delete_task, db_session_with_containers):
|
||||
def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session):
|
||||
"""
|
||||
Test conversation deletion with async cleanup.
|
||||
|
||||
@ -1071,7 +1080,7 @@ class TestConversationServiceExport:
|
||||
mock_delete_task.delay.assert_called_once_with(conversation_id)
|
||||
|
||||
@patch("services.conversation_service.delete_conversation_related_data")
|
||||
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers):
|
||||
def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session):
|
||||
"""
|
||||
Test deletion is denied when conversation belongs to a different account.
|
||||
"""
|
||||
@ -1102,7 +1111,7 @@ class TestConversationServiceExport:
|
||||
mock_delete_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.conversation_service.delete_conversation_related_data")
|
||||
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
|
||||
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that delete propagates exceptions and does not trigger the cleanup task.
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_conversation_service_session_factory(flask_app_with_containers):
|
||||
def real_conversation_service_session_factory(flask_app_with_containers: Flask):
|
||||
del flask_app_with_containers
|
||||
real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers):
|
||||
|
||||
class TestConversationServiceVariables:
|
||||
def test_get_conversational_variable_success(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -200,7 +201,7 @@ class TestConversationServiceVariables:
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_conversational_variable_with_last_id(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -242,7 +243,7 @@ class TestConversationServiceVariables:
|
||||
assert result.has_more is False
|
||||
|
||||
def test_get_conversational_variable_last_id_not_found_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -259,7 +260,7 @@ class TestConversationServiceVariables:
|
||||
)
|
||||
|
||||
def test_get_conversational_variable_sets_has_more(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -287,7 +288,7 @@ class TestConversationServiceVariables:
|
||||
assert result.has_more is True
|
||||
|
||||
def test_update_conversation_variable_success(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -320,7 +321,7 @@ class TestConversationServiceVariables:
|
||||
assert result["updated_at"] == updated_at
|
||||
|
||||
def test_update_conversation_variable_not_found_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -337,7 +338,7 @@ class TestConversationServiceVariables:
|
||||
)
|
||||
|
||||
def test_update_conversation_variable_type_mismatch_raises_error(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -360,7 +361,7 @@ class TestConversationServiceVariables:
|
||||
)
|
||||
|
||||
def test_update_conversation_variable_integer_number_compatibility(
|
||||
self, db_session_with_containers, real_conversation_service_session_factory
|
||||
self, db_session_with_containers: Session, real_conversation_service_session_factory
|
||||
):
|
||||
del real_conversation_service_session_factory
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
@ -390,7 +391,7 @@ class TestConversationServiceVariables:
|
||||
|
||||
|
||||
class TestConversationServicePaginationWithContainers:
|
||||
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers):
|
||||
def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
|
||||
@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers:
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers):
|
||||
def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
base_time = datetime(2024, 1, 1, 8, 0, 0)
|
||||
@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers:
|
||||
assert newest.id != middle.id
|
||||
assert [conversation.id for conversation in result.data] == [oldest.id]
|
||||
|
||||
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers):
|
||||
def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha")
|
||||
@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers:
|
||||
assert alpha.id != beta.id
|
||||
assert [conversation.id for conversation in result.data] == [gamma.id]
|
||||
|
||||
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers):
|
||||
def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
end_user = factory.create_end_user(db_session_with_containers, app)
|
||||
@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers:
|
||||
assert account_conversation.id != end_user_conversation.id
|
||||
assert [conversation.id for conversation in result.data] == [end_user_conversation.id]
|
||||
|
||||
def test_pagination_filters_to_account_console_source(self, db_session_with_containers):
|
||||
def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session):
|
||||
factory = ConversationServiceVariableIntegrationFactory
|
||||
app, account = factory.create_app_and_account(db_session_with_containers)
|
||||
end_user = factory.create_end_user(db_session_with_containers, app)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.variables import StringVariable
|
||||
@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE
|
||||
|
||||
class TestConversationVariableUpdater:
|
||||
def _create_conversation_variable(
|
||||
self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
*,
|
||||
conversation_id: str,
|
||||
variable: StringVariable,
|
||||
app_id: str | None = None,
|
||||
) -> ConversationVariable:
|
||||
row = ConversationVariable(
|
||||
id=variable.id,
|
||||
@ -25,7 +30,7 @@ class TestConversationVariableUpdater:
|
||||
db_session_with_containers.commit()
|
||||
return row
|
||||
|
||||
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers):
|
||||
def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session):
|
||||
conversation_id = str(uuid4())
|
||||
variable = StringVariable(id=str(uuid4()), name="topic", value="old value")
|
||||
self._create_conversation_variable(
|
||||
@ -42,7 +47,7 @@ class TestConversationVariableUpdater:
|
||||
assert row is not None
|
||||
assert row.data == updated_variable.model_dump_json()
|
||||
|
||||
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers):
|
||||
def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session):
|
||||
conversation_id = str(uuid4())
|
||||
variable = StringVariable(id=str(uuid4()), name="topic", value="value")
|
||||
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
|
||||
@ -50,7 +55,7 @@ class TestConversationVariableUpdater:
|
||||
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers):
|
||||
def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session):
|
||||
updater = ConversationVariableUpdater(sessionmaker(bind=db.engine))
|
||||
|
||||
result = updater.flush()
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
@ -14,7 +15,7 @@ class TestCreditPoolService:
|
||||
def _create_tenant_id(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
def test_create_default_pool(self, db_session_with_containers):
|
||||
def test_create_default_pool(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
@ -25,7 +26,7 @@ class TestCreditPoolService:
|
||||
assert pool.quota_used == 0
|
||||
assert pool.quota_limit > 0
|
||||
|
||||
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers):
|
||||
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
@ -35,17 +36,17 @@ class TestCreditPoolService:
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.pool_type == ProviderQuotaType.TRIAL
|
||||
|
||||
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session):
|
||||
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers):
|
||||
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session):
|
||||
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers):
|
||||
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
@ -53,7 +54,7 @@ class TestCreditPoolService:
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers):
|
||||
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
# Exhaust credits
|
||||
@ -64,11 +65,11 @@ class TestCreditPoolService:
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers):
|
||||
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session):
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers):
|
||||
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
pool.quota_used = pool.quota_limit
|
||||
@ -77,7 +78,7 @@ class TestCreditPoolService:
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
|
||||
|
||||
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers):
|
||||
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
credits_required = 10
|
||||
@ -89,7 +90,7 @@ class TestCreditPoolService:
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers):
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
|
||||
@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory:
|
||||
class TestDatasetPermissionServiceGetPartialMemberList:
|
||||
"""Verify partial-member list reads against persisted DatasetPermission rows."""
|
||||
|
||||
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
|
||||
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving partial member list with multiple members.
|
||||
"""
|
||||
@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
|
||||
assert set(result) == set(expected_account_ids)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
|
||||
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving partial member list with single member.
|
||||
"""
|
||||
@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
|
||||
assert set(result) == set(expected_account_ids)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
|
||||
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieving partial member list when no members exist.
|
||||
"""
|
||||
@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList:
|
||||
class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
"""Verify partial-member list updates against persisted DatasetPermission rows."""
|
||||
|
||||
def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
|
||||
def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test adding new partial members to a dataset.
|
||||
"""
|
||||
@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert set(result) == {member_1.id, member_2.id}
|
||||
|
||||
def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
|
||||
def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test replacing existing partial members with new ones.
|
||||
"""
|
||||
@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert set(result) == {new_member_1.id, new_member_2.id}
|
||||
|
||||
def test_update_partial_member_list_empty_list(self, db_session_with_containers):
|
||||
def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test updating with empty member list (clearing all members).
|
||||
"""
|
||||
@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
|
||||
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error handling and rollback on database error.
|
||||
"""
|
||||
@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
"""Verify partial-member clearing against persisted DatasetPermission rows."""
|
||||
|
||||
def test_clear_partial_member_list_success(self, db_session_with_containers):
|
||||
def test_clear_partial_member_list_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful clearing of partial member list.
|
||||
"""
|
||||
@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
|
||||
def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test clearing partial member list when no members exist.
|
||||
"""
|
||||
@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
|
||||
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error handling and rollback on database error.
|
||||
"""
|
||||
@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
class TestDatasetServiceCheckDatasetPermission:
|
||||
"""Verify dataset access checks against persisted partial-member permissions."""
|
||||
|
||||
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session):
|
||||
"""Test that users from different tenants cannot access dataset."""
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, other_user)
|
||||
|
||||
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session):
|
||||
"""Test that tenant owners can access any dataset regardless of permission level."""
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, owner)
|
||||
|
||||
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session):
|
||||
"""Test ONLY_ME permission allows only the dataset creator to access."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
|
||||
@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, creator)
|
||||
|
||||
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session):
|
||||
"""Test ONLY_ME permission denies access to non-creators."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.check_dataset_permission(dataset, other)
|
||||
|
||||
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session):
|
||||
"""Test ALL_TEAM permission allows any team member to access the dataset."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
|
||||
DatasetService.check_dataset_permission(dataset, member)
|
||||
|
||||
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_partial_members_with_permission_success(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test that user with explicit permission can access partial_members dataset.
|
||||
"""
|
||||
@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert user.id in permissions
|
||||
|
||||
def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_partial_members_without_permission_error(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""
|
||||
Test error when user without permission tries to access partial_members dataset.
|
||||
"""
|
||||
@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission:
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers):
|
||||
def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session):
|
||||
"""Test PARTIAL_TEAM permission allows creator to access without explicit permission."""
|
||||
creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR)
|
||||
|
||||
|
||||
@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration:
|
||||
class TestDocumentServicePauseRecoverRetry:
|
||||
"""Tests for pause/recover/retry orchestration using real DB and Redis."""
|
||||
|
||||
def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"):
|
||||
def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"):
|
||||
factory = DatasetServiceIntegrationDataFactory
|
||||
account, tenant = factory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id)
|
||||
@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry:
|
||||
db_session_with_containers.commit()
|
||||
return doc, account
|
||||
|
||||
def test_pause_document_success(self, db_session_with_containers):
|
||||
def test_pause_document_success(self, db_session_with_containers: Session):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry:
|
||||
assert redis_client.get(cache_key) is not None
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
def test_pause_document_invalid_status_error(self, db_session_with_containers):
|
||||
def test_pause_document_invalid_status_error(self, db_session_with_containers: Session):
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
|
||||
@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry:
|
||||
with pytest.raises(DocumentIndexingError):
|
||||
DocumentService.pause_document(doc)
|
||||
|
||||
def test_recover_document_success(self, db_session_with_containers):
|
||||
def test_recover_document_success(self, db_session_with_containers: Session):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry:
|
||||
assert redis_client.get(cache_key) is None
|
||||
recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id)
|
||||
|
||||
def test_retry_document_indexing_success(self, db_session_with_containers):
|
||||
def test_retry_document_indexing_success(self, db_session_with_containers: Session):
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from services.dataset_service import DatasetService
|
||||
@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session):
|
||||
tenant, _ = self._create_tenant_and_account(db_session_with_containers)
|
||||
|
||||
mock_user = Mock(id=None)
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory:
|
||||
class TestDatasetServiceDeleteDataset:
|
||||
"""Integration coverage for DatasetService.delete_dataset using testcontainers."""
|
||||
|
||||
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
|
||||
def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
|
||||
"""Delete a dataset with documents and dispatch cleanup through the real signal handler."""
|
||||
# Arrange
|
||||
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
dataset.pipeline_id,
|
||||
)
|
||||
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers):
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
|
||||
"""Delete an empty dataset without scheduling cleanup when both gating fields are absent."""
|
||||
# Arrange
|
||||
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
clean_dataset_delay.assert_not_called()
|
||||
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
|
||||
"""Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves."""
|
||||
# Arrange
|
||||
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
clean_dataset_delay.assert_not_called()
|
||||
|
||||
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers):
|
||||
def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session):
|
||||
"""Delete a dataset without cleanup when indexing exists but doc_form resolves to None."""
|
||||
# Arrange
|
||||
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
clean_dataset_delay.assert_not_called()
|
||||
|
||||
def test_delete_dataset_not_found(self, db_session_with_containers):
|
||||
def test_delete_dataset_not_found(self, db_session_with_containers: Session):
|
||||
"""Return False without scheduling cleanup when the target dataset does not exist."""
|
||||
# Arrange
|
||||
owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle:
|
||||
|
||||
DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
|
||||
|
||||
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers):
|
||||
def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask):
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetService.update_dataset_api_status(str(uuid4()), True)
|
||||
@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration:
|
||||
assert persisted.type == "dataset"
|
||||
assert persisted.collection_name
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask):
|
||||
with flask_app_with_containers.app_context():
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))
|
||||
|
||||
@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
db_session_with_containers.commit()
|
||||
return run
|
||||
|
||||
def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
|
||||
def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None:
|
||||
archive_log = WorkflowArchiveLog(
|
||||
tenant_id=run.tenant_id,
|
||||
app_id=run.app_id,
|
||||
@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
db_session_with_containers.add(archive_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
|
||||
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session):
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
missing_run_id = str(uuid4())
|
||||
|
||||
@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
assert result.success is False
|
||||
assert result.error == f"Workflow run {missing_run_id} not found"
|
||||
|
||||
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
|
||||
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
run = self._create_workflow_run(
|
||||
db_session_with_containers,
|
||||
@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
assert result.success is False
|
||||
assert result.error == f"Workflow run {run.id} is not archived"
|
||||
|
||||
def test_delete_batch_uses_repo(self, db_session_with_containers):
|
||||
def test_delete_batch_uses_repo(self, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
base_time = datetime.now(UTC)
|
||||
run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
|
||||
@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
).all()
|
||||
assert remaining_runs == []
|
||||
|
||||
def test_delete_run_calls_repo(self, db_session_with_containers):
|
||||
def test_delete_run_calls_repo(self, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
run = self._create_workflow_run(
|
||||
db_session_with_containers,
|
||||
@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
|
||||
assert deleted_run is None
|
||||
|
||||
def test_delete_run_dry_run(self, db_session_with_containers):
|
||||
def test_delete_run_dry_run(self, db_session_with_containers: Session):
|
||||
"""Dry run should return success without actually deleting."""
|
||||
tenant_id = str(uuid4())
|
||||
run = self._create_workflow_run(
|
||||
@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
db_session_with_containers.expire_all()
|
||||
assert db_session_with_containers.get(WorkflowRun, run_id) is not None
|
||||
|
||||
def test_delete_run_exception_returns_error(self, db_session_with_containers):
|
||||
def test_delete_run_exception_returns_error(self, db_session_with_containers: Session):
|
||||
"""Exception during deletion should return failure result."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
assert result.success is False
|
||||
assert result.error == "Database error"
|
||||
|
||||
def test_delete_by_run_id_success(self, db_session_with_containers):
|
||||
def test_delete_by_run_id_success(self, db_session_with_containers: Session):
|
||||
"""Successfully delete an archived workflow run by ID."""
|
||||
tenant_id = str(uuid4())
|
||||
base_time = datetime.now(UTC)
|
||||
@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
db_session_with_containers.expunge_all()
|
||||
assert db_session_with_containers.get(WorkflowRun, run_id) is None
|
||||
|
||||
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
|
||||
def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session):
|
||||
"""_get_workflow_run_repo should return a cached repo on subsequent calls."""
|
||||
deleter = ArchivedWorkflowRunDeletion()
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser:
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
|
||||
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory):
|
||||
"""Test getting or creating end user with custom user_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser:
|
||||
assert result.type == InvokeFrom.SERVICE_API
|
||||
assert result.is_anonymous is False
|
||||
|
||||
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
|
||||
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory):
|
||||
"""Test getting or creating end user without user_id uses default session."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser:
|
||||
# Verify _is_anonymous is set correctly (property always returns False)
|
||||
assert result._is_anonymous is True
|
||||
|
||||
def test_get_existing_end_user(self, db_session_with_containers, factory):
|
||||
def test_get_existing_end_user(self, db_session_with_containers: Session, factory):
|
||||
"""Test retrieving an existing end user."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
|
||||
def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory):
|
||||
"""Test creating new end user with SERVICE_API type."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert result.app_id == app_id
|
||||
assert result.session_id == user_id
|
||||
|
||||
def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
|
||||
def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory):
|
||||
"""Test creating new end user with WEB_APP type."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert result.type == InvokeFrom.WEB_APP
|
||||
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
|
||||
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory):
|
||||
"""Test upgrading legacy end user with different type."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert "Upgrading legacy EndUser" in log_call
|
||||
|
||||
@patch("services.end_user_service.logger")
|
||||
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
|
||||
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory):
|
||||
"""Test retrieving existing end user with matching type."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert result.type == InvokeFrom.SERVICE_API
|
||||
mock_logger.info.assert_not_called()
|
||||
|
||||
def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
|
||||
def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory):
|
||||
"""Test creating anonymous user when user_id is None."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert result._is_anonymous is True
|
||||
assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
|
||||
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory):
|
||||
"""Test that query ordering prioritizes records with matching type."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
assert result.id == matching.id
|
||||
assert result.id != non_matching.id
|
||||
|
||||
def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
|
||||
def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory):
|
||||
"""Test that external_user_id is set to match session_id."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType:
|
||||
InvokeFrom.DEBUGGER,
|
||||
],
|
||||
)
|
||||
def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
|
||||
def test_create_end_user_with_different_invoke_types(
|
||||
self, db_session_with_containers: Session, invoke_type, factory
|
||||
):
|
||||
"""Test creating end users with different InvokeFrom types."""
|
||||
# Arrange
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById:
|
||||
"""Provide test data factory."""
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
|
||||
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory):
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
existing_user = factory.create_end_user(
|
||||
db_session_with_containers,
|
||||
@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById:
|
||||
assert result is not None
|
||||
assert result.id == existing_user.id
|
||||
|
||||
def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
|
||||
def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory):
|
||||
app = factory.create_app_and_account(db_session_with_containers)
|
||||
|
||||
result = EndUserService.get_end_user_by_id(
|
||||
@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch:
|
||||
def factory(self):
|
||||
return TestEndUserServiceFactory()
|
||||
|
||||
def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
|
||||
def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3):
|
||||
"""Create multiple apps under the same tenant."""
|
||||
first_app = factory.create_app_and_account(db_session_with_containers)
|
||||
tenant_id = first_app.tenant_id
|
||||
@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch:
|
||||
all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
|
||||
return tenant_id, all_apps
|
||||
|
||||
def test_create_batch_empty_app_ids(self, db_session_with_containers):
|
||||
def test_create_batch_empty_app_ids(self, db_session_with_containers: Session):
|
||||
result = EndUserService.create_end_user_batch(
|
||||
type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
|
||||
def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
|
||||
app_ids = [a.id for a in apps]
|
||||
user_id = f"user-{uuid4()}"
|
||||
@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch:
|
||||
assert result[app_id].session_id == user_id
|
||||
assert result[app_id].type == InvokeFrom.SERVICE_API
|
||||
|
||||
def test_create_batch_default_session_id(self, db_session_with_containers, factory):
|
||||
def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [a.id for a in apps]
|
||||
|
||||
@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch:
|
||||
assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
assert end_user._is_anonymous is True
|
||||
|
||||
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
|
||||
def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
|
||||
user_id = f"user-{uuid4()}"
|
||||
@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch:
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
|
||||
def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
|
||||
app_ids = [a.id for a in apps]
|
||||
user_id = f"user-{uuid4()}"
|
||||
@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch:
|
||||
for app_id in app_ids:
|
||||
assert first_result[app_id].id == second_result[app_id].id
|
||||
|
||||
def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
|
||||
def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch:
|
||||
"invoke_type",
|
||||
[InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
|
||||
)
|
||||
def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
|
||||
def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory):
|
||||
tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
|
||||
user_id = f"user-{uuid4()}"
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.feature_service import (
|
||||
@ -81,7 +82,7 @@ class TestFeatureService:
|
||||
fake = Faker()
|
||||
return fake.uuid4()
|
||||
|
||||
def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful feature retrieval with billing and enterprise enabled.
|
||||
|
||||
@ -156,7 +157,7 @@ class TestFeatureService:
|
||||
tenant_id
|
||||
)
|
||||
|
||||
def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test feature retrieval for sandbox plan with specific limitations.
|
||||
|
||||
@ -222,7 +223,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_knowledge_rate_limit_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful knowledge rate limit retrieval with billing enabled.
|
||||
|
||||
@ -255,7 +258,7 @@ class TestFeatureService:
|
||||
tenant_id
|
||||
)
|
||||
|
||||
def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful system features retrieval with enterprise and marketplace enabled.
|
||||
|
||||
@ -332,7 +335,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_system_features_unauthenticated(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval for an unauthenticated user.
|
||||
|
||||
@ -386,7 +391,9 @@ class TestFeatureService:
|
||||
# Marketplace should be visible
|
||||
assert result.enable_marketplace is True
|
||||
|
||||
def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_system_features_basic_config(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with basic configuration (no enterprise).
|
||||
|
||||
@ -436,7 +443,9 @@ class TestFeatureService:
|
||||
# Verify plugin package size (uses default value from dify_config)
|
||||
assert result.max_plugin_package_size == 15728640
|
||||
|
||||
def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_billing_disabled(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval when billing is disabled.
|
||||
|
||||
@ -492,7 +501,7 @@ class TestFeatureService:
|
||||
assert result.webapp_copyright_enabled is False
|
||||
|
||||
def test_get_knowledge_rate_limit_billing_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test knowledge rate limit retrieval when billing is disabled.
|
||||
@ -523,7 +532,9 @@ class TestFeatureService:
|
||||
# Verify no billing service calls
|
||||
mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called()
|
||||
|
||||
def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_enterprise_only(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with enterprise enabled but billing disabled.
|
||||
|
||||
@ -583,7 +594,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
|
||||
|
||||
def test_get_system_features_enterprise_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval when enterprise is disabled.
|
||||
@ -640,7 +651,7 @@ class TestFeatureService:
|
||||
# Verify no enterprise service calls
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called()
|
||||
|
||||
def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test feature retrieval without tenant ID (billing disabled).
|
||||
|
||||
@ -686,7 +697,9 @@ class TestFeatureService:
|
||||
# Verify no billing service calls
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
|
||||
|
||||
def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_partial_billing_info(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with partial billing information.
|
||||
|
||||
@ -746,7 +759,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_edge_case_vector_space(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case vector space configuration.
|
||||
|
||||
@ -807,7 +822,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_system_features_edge_case_webapp_auth(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with edge case webapp auth configuration.
|
||||
@ -863,7 +878,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_edge_case_members_quota(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case members quota configuration.
|
||||
|
||||
@ -924,7 +941,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_plugin_installation_permission_scopes(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with different plugin installation permission scopes.
|
||||
@ -1023,7 +1040,7 @@ class TestFeatureService:
|
||||
assert result.plugin_installation_permission.restrict_to_marketplace_only is True
|
||||
|
||||
def test_get_features_workspace_members_missing(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval when workspace members info is missing from enterprise.
|
||||
@ -1064,7 +1081,9 @@ class TestFeatureService:
|
||||
tenant_id
|
||||
)
|
||||
|
||||
def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_system_features_license_inactive(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with inactive license.
|
||||
|
||||
@ -1117,7 +1136,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_system_features_partial_enterprise_info(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with partial enterprise information.
|
||||
@ -1186,7 +1205,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_edge_case_limits(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case limit values.
|
||||
|
||||
@ -1244,7 +1265,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_system_features_edge_case_protocols(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with edge case protocol values.
|
||||
@ -1297,7 +1318,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_features_edge_case_education(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case education configuration.
|
||||
|
||||
@ -1353,7 +1376,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_license_limitation_model_is_available(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test LicenseLimitationModel.is_available method with various scenarios.
|
||||
@ -1394,7 +1417,7 @@ class TestFeatureService:
|
||||
assert exact_limit.is_available(3) is True
|
||||
|
||||
def test_get_features_workspace_members_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval when workspace members are disabled in enterprise.
|
||||
@ -1433,7 +1456,9 @@ class TestFeatureService:
|
||||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_system_features_license_expired(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with expired license.
|
||||
|
||||
@ -1486,7 +1511,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_docs_processing(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case document processing configuration.
|
||||
@ -1544,7 +1569,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_system_features_edge_case_branding(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features retrieval with edge case branding configuration.
|
||||
@ -1606,7 +1631,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_annotation_quota(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case annotation quota configuration.
|
||||
@ -1668,7 +1693,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_features_edge_case_documents_upload(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with edge case documents upload settings.
|
||||
@ -1733,7 +1758,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_system_features_edge_case_license_lost(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test system features with lost license status.
|
||||
@ -1784,7 +1809,7 @@ class TestFeatureService:
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_features_edge_case_education_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test feature retrieval with education feature disabled.
|
||||
|
||||
@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.human_input_adapter import (
|
||||
@ -88,7 +89,7 @@ class TestDeliveryTestRegistry:
|
||||
with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."):
|
||||
registry.dispatch(context=context, method=method)
|
||||
|
||||
def test_default(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_default(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
registry = DeliveryTestRegistry.default()
|
||||
assert len(registry._handlers) == 1
|
||||
assert isinstance(registry._handlers[0], EmailDeliveryTestHandler)
|
||||
@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler:
|
||||
)
|
||||
assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"]
|
||||
|
||||
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
account = Account(name="Test User", email="member@example.com")
|
||||
db_session_with_containers.add(account)
|
||||
@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler:
|
||||
)
|
||||
assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"]
|
||||
|
||||
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers):
|
||||
def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session):
|
||||
tenant_id = str(uuid4())
|
||||
account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com")
|
||||
account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com")
|
||||
|
||||
@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
return app
|
||||
|
||||
def _create_conversation(self, db_session_with_containers: Session, app):
|
||||
def _create_conversation(self, db_session_with_containers: Session, app: App):
|
||||
"""Helper to create a conversation."""
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
|
||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import Dataset, DatasetMetadataBinding, Document
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom
|
||||
@ -65,7 +66,7 @@ class TestMetadataPartialUpdate:
|
||||
yield account
|
||||
|
||||
def test_partial_update_merges_metadata(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
@ -92,7 +93,7 @@ class TestMetadataPartialUpdate:
|
||||
assert updated_doc.doc_metadata["new_key"] == "new_value"
|
||||
|
||||
def test_full_update_replaces_metadata(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
@ -119,7 +120,7 @@ class TestMetadataPartialUpdate:
|
||||
assert "existing_key" not in updated_doc.doc_metadata
|
||||
|
||||
def test_partial_update_skips_existing_binding(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
@ -159,7 +160,7 @@ class TestMetadataPartialUpdate:
|
||||
assert len(bindings) == 1
|
||||
|
||||
def test_rollback_called_on_commit_failure(
|
||||
self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account
|
||||
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id)
|
||||
document = _create_document(
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from models.model import OAuthProviderApp
|
||||
@ -25,7 +26,7 @@ from services.oauth_server import (
|
||||
class TestOAuthServerServiceGetProviderApp:
|
||||
"""DB-backed tests for get_oauth_provider_app."""
|
||||
|
||||
def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp:
|
||||
def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp:
|
||||
app = OAuthProviderApp(
|
||||
app_icon="icon.png",
|
||||
client_id=client_id,
|
||||
@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp:
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers):
|
||||
def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session):
|
||||
client_id = f"client-{uuid4()}"
|
||||
created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id)
|
||||
|
||||
@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp:
|
||||
assert result.client_id == client_id
|
||||
assert result.id == created.id
|
||||
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers):
|
||||
def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session):
|
||||
result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}")
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -8,6 +8,7 @@ from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowPause, WorkflowRun
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
@ -39,7 +40,7 @@ class TestWorkflowRunRestore:
|
||||
assert result["created_at"].month == 1
|
||||
assert result["name"] == "test"
|
||||
|
||||
def test_restore_table_records_returns_rowcount(self, db_session_with_containers):
|
||||
def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session):
|
||||
"""Restore should return inserted rowcount."""
|
||||
restore = WorkflowRunRestore()
|
||||
record_id = str(uuid4())
|
||||
@ -65,7 +66,7 @@ class TestWorkflowRunRestore:
|
||||
restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id))
|
||||
assert restored_pause is not None
|
||||
|
||||
def test_restore_table_records_unknown_table(self, db_session_with_containers):
|
||||
def test_restore_table_records_unknown_table(self, db_session_with_containers: Session):
|
||||
"""Unknown table names should be ignored gracefully."""
|
||||
restore = WorkflowRunRestore()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user