merge hitl

This commit is contained in:
JzoNg 2026-05-09 13:09:54 +08:00
commit a6994cc680
1095 changed files with 52485 additions and 23955 deletions

View File

@ -1,6 +1,6 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---
# Frontend Query & Mutation
@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep invalidation and mutation flow knowledge in the service layer.
- Keep default cache behavior in the `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
- Keep abstractions minimal to preserve TypeScript inference.
## Workflow
1. Identify the change surface.
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
- Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
- Extract a small shared query helper only when multiple call sites share the same extra options.
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
- Create or keep feature hooks only for real orchestration or shared domain behavior.
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
3. Preserve Dify conventions.
- Keep contract inputs in `{ params, query?, body? }` shape (see the sketch after this list).
- Bind invalidation in the service-layer mutation definition.
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
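A minimal call-site sketch of these conventions. The `billing.invoices` and `billing.renameInvoice` operations are hypothetical; only the `{ params, query?, body? }` input shape and the `mutate(...)` preference are the point, with `consoleQuery` assumed to come from `web/service/client.ts`:
```typescript
// Sketch only: operation names are illustrative, not actual contract entries.
function useInvoiceActions(workspaceId: string) {
  const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
    input: { params: { workspaceId }, query: { page: 1 } },
  }))
  const renameMutation = useMutation(consoleQuery.billing.renameInvoice.mutationOptions())
  // Prefer mutate(...); reach for mutateAsync(...) only when Promise semantics are required.
  const rename = (invoiceId: string, name: string) =>
    renameMutation.mutate({ params: { invoiceId }, body: { name } })
  return { invoiceQuery, rename }
}
```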
## Files Commonly Touched
@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- `web/service/use-*.ts`
- legacy `web/service/use-*.ts` files when migrating wrappers away
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
## References

View File

@ -1,4 +1,4 @@
interface:
display_name: "Frontend Query & Mutation"
short_description: "Dify TanStack Query and oRPC patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
short_description: "Dify TanStack Query, oRPC, and default option patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."

View File

@ -7,6 +7,7 @@
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Thin hook decision rule
- Anti-patterns
- Contract rules
- Type export
@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create `web/service/use-{domain}.ts` only for orchestration.
3. Create or keep feature hooks only for orchestration.
- Combine multiple queries or mutations.
- Share domain-level derived state or invalidation helpers.
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
4. Treat `web/service/use-{domain}.ts` as legacy.
- Do not create new thin service wrappers for oRPC contracts.
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
```typescript
const invoicesBaseQueryOptions = () =>
@ -74,11 +79,37 @@ const invoiceQuery = useQuery({
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic (sketched after the example below).
```typescript
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
```
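A hedged sketch of rule 2, assuming a hypothetical `consoleClient.tags.bind` operation: the flow is custom, but the oRPC client still supplies `mutationFn` so the contract types flow through.
```typescript
// Heavily custom flow: the oRPC client provides mutationFn (operation name is illustrative).
const bindTagsMutation = useMutation({
  mutationFn: (input: { params: { appId: string }, body: { tagIds: string[] } }) =>
    consoleClient.tags.bind(input),
  onSuccess: () => {
    // Keep call-site callbacks to UI feedback; shared cache behavior stays in oRPC defaults.
  },
})
```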
## Thin Hook Decision Rule
Remove thin hooks when they only rename a single oRPC query or mutation helper.
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
Use:
```typescript
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
```
Keep:
```typescript
const applyTagBindingsMutation = useApplyTagBindingsMutation()
```
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>` (contrasted in the sketch after this list).
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
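For contrast, a sketch of the first anti-pattern next to the preferred direct call; the wrapper and operation names are hypothetical:
```typescript
// Anti-pattern: a passthrough wrapper with Partial<UseQueryOptions> degrades inference
// around select and throwOnError and only adds indirection.
const useInvoices = (options?: Partial<UseQueryOptions>) =>
  useQuery({ ...consoleQuery.billing.invoices.queryOptions(), ...options })

// Preferred: consume the contract-shaped options directly at the call site.
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions())
```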
## Contract Rules

View File

@ -3,6 +3,7 @@
## Table of Contents
- Conditional queries
- oRPC default options
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) {
}
```
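A minimal sketch of the preferred conditional query, assuming an `app.accessMode` contract operation exists: keep the hook unconditional and gate the request with `enabled`.
```typescript
// The hook is always called; `enabled` gates the request (operation path is illustrative).
function useAccessMode(appId: string | undefined) {
  return useQuery({
    ...consoleQuery.app.accessMode.queryOptions({ input: { params: { appId: appId ?? '' } } }),
    enabled: Boolean(appId),
  })
}
```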
## oRPC Default Options
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
Place defaults at the query utility creation point in `web/service/client.ts`:
```typescript
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
path: ['console'],
experimental_defaults: {
tags: {
create: {
mutationOptions: {
onSuccess: (tag, _variables, _result, context) => {
context.client.setQueryData(
consoleQuery.tags.list.queryKey({
input: {
query: {
type: tag.type,
},
},
}),
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
)
},
},
},
},
},
})
```
Rules:
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
## Cache Invalidation
Bind invalidation in the service-layer mutation definition.
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
Use:
@ -49,7 +91,7 @@ Use:
Do not use deprecated `useInvalid` from `use-base.ts`.
```typescript
// Service layer owns cache invalidation.
// Feature orchestration owns cache invalidation only when defaults are not enough.
export const useUpdateAccessMode = () => {
const queryClient = useQueryClient()
@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules:
| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` (sketched below) |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
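A hedged before/after sketch for the imperative-fetch row, assuming a hypothetical `app.updateAccessMode` operation:
```typescript
// Before (legacy sketch): imperative call plus manual invalidation at the component.
// await updateAccessMode(appId, body); queryClient.invalidateQueries({ queryKey })

// After: a contract-shaped mutation; invalidation lives in oRPC defaults or the
// service-layer mutation definition, not at the call site (names are illustrative).
const updateAccessModeMutation = useMutation(
  consoleQuery.app.updateAccessMode.mutationOptions(),
)
const save = (appId: string, mode: string) =>
  updateAccessModeMutation.mutate({ params: { appId }, body: { mode } })
```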

View File

@ -99,7 +99,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

View File

@ -116,6 +116,12 @@ jobs:
if: github.event_name != 'merge_group'
uses: ./.github/actions/setup-web
- name: Generate API docs
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
run: |
cd api
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
- name: ESLint autofix
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
run: |

View File

@ -37,7 +37,7 @@ jobs:
- name: Prepare middleware env
run: |
cd docker
cp middleware.env.example middleware.env
cp envs/middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
@ -87,7 +87,7 @@ jobs:
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
cp envs/middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env

View File

@ -57,7 +57,7 @@ jobs:
- '.github/workflows/api-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/middleware.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'
@ -84,7 +84,7 @@ jobs:
- 'pnpm-workspace.yaml'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
- 'docker/envs/middleware.env.example'
- '.github/workflows/web-e2e.yml'
- '.github/actions/setup-web/**'
vdb:
@ -94,7 +94,7 @@ jobs:
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/middleware.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'
@ -116,7 +116,7 @@ jobs:
- '.github/workflows/db-migration-test.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'
- 'docker/middleware.env.example'
- 'docker/envs/middleware.env.example'
- 'docker/docker-compose.middleware.yaml'
- 'docker/docker-compose-template.yaml'
- 'docker/generate_docker_compose'

View File

@ -107,6 +107,8 @@ jobs:
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check

View File

@ -51,7 +51,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

View File

@ -48,7 +48,7 @@ jobs:
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
cp docker/envs/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

View File

@ -71,13 +71,13 @@ type-check:
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@./dev/pyrefly-check-local
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
type-check-core:
@echo "📝 Running core type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Core type checks complete"
test:

View File

@ -137,7 +137,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, edit `docker/.env`. The essential startup defaults live in [`docker/.env.example`](docker/.env.example), and optional advanced variables are split under `docker/envs/` by theme. After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana

View File

@ -98,6 +98,8 @@ DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30
# Connection pool reset behavior on return
SQLALCHEMY_POOL_RESET_ON_RETURN=rollback
# Storage configuration
# use for store upload files, private keys...
@ -381,7 +383,7 @@ VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
VIKINGDB_REGION=cn-shanghai
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
VIKINGDB_SCHEMA=http
VIKINGDB_SCHEME=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
@ -432,8 +434,6 @@ UPLOAD_FILE_EXTENSION_BLACKLIST=
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp, sendgrid

View File

@ -17,7 +17,7 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
git g++ \
# for building gmpy2
libmpfr-dev libmpc-dev

View File

@ -114,7 +114,7 @@ class SQLAlchemyEngineOptionsDict(TypedDict):
pool_pre_ping: bool
connect_args: dict[str, str]
pool_use_lifo: bool
pool_reset_on_return: None
pool_reset_on_return: Literal["commit", "rollback", None]
pool_timeout: int
@ -223,6 +223,11 @@ class DatabaseConfig(BaseSettings):
default=30,
)
SQLALCHEMY_POOL_RESET_ON_RETURN: Literal["commit", "rollback", None] = Field(
description="Connection pool reset behavior on return. Options: 'commit', 'rollback', or None",
default="rollback",
)
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count() or 1,
@ -252,7 +257,7 @@ class DatabaseConfig(BaseSettings):
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
"pool_reset_on_return": None,
"pool_reset_on_return": self.SQLALCHEMY_POOL_RESET_ON_RETURN,
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
}
return result

View File

@ -19,7 +19,7 @@
"name": "Website Generator"
},
"app_id": "b53545b1-79ea-4da3-b31a-c39391c6f041",
"category": "Programming",
"categories": ["Programming"],
"copyright": null,
"description": null,
"is_listed": true,
@ -35,7 +35,7 @@
"name": "Investment Analysis Report Copilot"
},
"app_id": "a23b57fa-85da-49c0-a571-3aff375976c1",
"category": "Agent",
"categories": ["Agent"],
"copyright": "Dify.AI",
"description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n",
"is_listed": true,
@ -51,7 +51,7 @@
"name": "Workflow Planning Assistant "
},
"app_id": "f3303a7d-a81c-404e-b401-1f8711c998c1",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "An assistant that helps you plan and select the right node for a workflow (V0.6.0). ",
"is_listed": true,
@ -67,7 +67,7 @@
"name": "Automated Email Reply "
},
"app_id": "e9d92058-7d20-4904-892f-75d90bef7587",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Reply emails using Gmail API. It will automatically retrieve email in your inbox and create a response in Gmail. \nConfigure your Gmail API in Google Cloud Console. ",
"is_listed": true,
@ -83,7 +83,7 @@
"name": "Book Translation "
},
"app_id": "98b87f88-bd22-4d86-8b74-86beba5e0ed4",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "A workflow designed to translate a full book up to 15000 tokens per run. Uses Code node to separate text into chunks and Iteration to translate each chunk. ",
"is_listed": true,
@ -99,7 +99,7 @@
"name": "Python bug fixer"
},
"app_id": "cae337e6-aec5-4c7b-beca-d6f1a808bd5e",
"category": "Programming",
"categories": ["Programming"],
"copyright": null,
"description": null,
"is_listed": true,
@ -115,7 +115,7 @@
"name": "Code Interpreter"
},
"app_id": "d077d587-b072-4f2c-b631-69ed1e7cdc0f",
"category": "Programming",
"categories": ["Programming"],
"copyright": "Copyright 2023 Dify",
"description": "Code interpreter, clarifying the syntax and semantics of the code.",
"is_listed": true,
@ -131,7 +131,7 @@
"name": "SVG Logo Design "
},
"app_id": "73fbb5f1-c15d-4d74-9cc8-46d9db9b2cca",
"category": "Agent",
"categories": ["Agent"],
"copyright": "Dify.AI",
"description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL·E 3. ",
"is_listed": true,
@ -147,7 +147,7 @@
"name": "Long Story Generator (Iteration) "
},
"app_id": "5efb98d7-176b-419c-b6ef-50767391ab62",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "A workflow demonstrating how to use Iteration node to generate long article that is longer than the context length of LLMs. ",
"is_listed": true,
@ -163,7 +163,7 @@
"name": "Text Summarization Workflow"
},
"app_id": "f00c4531-6551-45ee-808f-1d7903099515",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Based on users' choice, retrieve external knowledge to more accurately summarize articles.",
"is_listed": true,
@ -179,7 +179,7 @@
"name": "YouTube Channel Data Analysis"
},
"app_id": "be591209-2ca8-410f-8f3b-ca0e530dd638",
"category": "Agent",
"categories": ["Agent"],
"copyright": "Dify.AI",
"description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ",
"is_listed": true,
@ -195,7 +195,7 @@
"name": "Article Grading Bot"
},
"app_id": "a747f7b4-c48b-40d6-b313-5e628232c05f",
"category": "Writing",
"categories": ["Writing"],
"copyright": null,
"description": "Assess the quality of articles and text based on user defined criteria. ",
"is_listed": true,
@ -211,7 +211,7 @@
"name": "SEO Blog Generator"
},
"app_id": "18f3bd03-524d-4d7a-8374-b30dbe7c69d5",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Workflow for retrieving information from the internet, followed by segmented generation of SEO blogs.",
"is_listed": true,
@ -227,7 +227,7 @@
"name": "SQL Creator"
},
"app_id": "050ef42e-3e0c-40c1-a6b6-a64f2c49d744",
"category": "Programming",
"categories": ["Programming"],
"copyright": "Copyright 2023 Dify",
"description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.",
"is_listed": true,
@ -243,7 +243,7 @@
"name": "Sentiment Analysis "
},
"app_id": "f06bf86b-d50c-4895-a942-35112dbe4189",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Batch sentiment analysis of text, followed by JSON output of sentiment classification along with scores.",
"is_listed": true,
@ -259,7 +259,7 @@
"name": "Strategic Consulting Expert"
},
"app_id": "7e8ca1ae-02f2-4b5f-979e-62d19133bee2",
"category": "Assistant",
"categories": ["Assistant"],
"copyright": "Copyright 2023 Dify",
"description": "I can answer your questions related to strategic marketing.",
"is_listed": true,
@ -275,7 +275,7 @@
"name": "Code Converter"
},
"app_id": "4006c4b2-0735-4f37-8dbb-fb1a8c5bd87a",
"category": "Programming",
"categories": ["Programming"],
"copyright": "Copyright 2023 Dify",
"description": "This is an application that provides the ability to convert code snippets in multiple programming languages. You can input the code you wish to convert, select the target programming language, and get the desired output.",
"is_listed": true,
@ -291,7 +291,7 @@
"name": "Question Classifier + Knowledge + Chatbot "
},
"app_id": "d9f6b733-e35d-4a40-9f38-ca7bbfa009f7",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Basic Workflow Template, a chatbot capable of identifying intents alongside with a knowledge base.",
"is_listed": true,
@ -307,7 +307,7 @@
"name": "AI Front-end interviewer"
},
"app_id": "127efead-8944-4e20-ba9d-12402eb345e0",
"category": "HR",
"categories": ["HR"],
"copyright": "Copyright 2023 Dify",
"description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.",
"is_listed": true,
@ -323,7 +323,7 @@
"name": "Knowledge Retrieval + Chatbot "
},
"app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Basic Workflow Template, A chatbot with a knowledge base. ",
"is_listed": true,
@ -339,7 +339,7 @@
"name": "Email Assistant Workflow "
},
"app_id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "A multifunctional email assistant capable of summarizing, replying, composing, proofreading, and checking grammar.",
"is_listed": true,
@ -355,7 +355,7 @@
"name": "Customer Review Analysis Workflow "
},
"app_id": "9c0cd31f-4b62-4005-adf5-e3888d08654a",
"category": "Workflow",
"categories": ["Workflow"],
"copyright": null,
"description": "Utilize LLM (Large Language Models) to classify customer reviews and forward them to the internal system.",
"is_listed": true,

View File

@ -1,6 +1,36 @@
from pydantic import BaseModel, JsonValue
from pydantic import BaseModel, Field, JsonValue
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
action: str

View File

@ -1,4 +1,10 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
"""Helpers for registering Pydantic models with Flask-RESTX namespaces.
Flask-RESTX treats `SchemaModel` bodies as opaque JSON schemas; it does not
promote Pydantic's nested `$defs` into top-level Swagger `definitions`.
These helpers keep that translation centralized so models registered through
`register_schema_models` emit resolvable Swagger 2.0 references.
"""
from enum import StrEnum
@ -8,10 +14,32 @@ from pydantic import BaseModel, TypeAdapter
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
nested_definitions = schema.get("$defs")
schema_to_register = dict(schema)
if isinstance(nested_definitions, dict):
schema_to_register.pop("$defs")
namespace.schema_model(name, schema_to_register)
if not isinstance(nested_definitions, dict):
return
for nested_name, nested_schema in nested_definitions.items():
if isinstance(nested_schema, dict):
_register_json_schema(namespace, nested_name, nested_schema)
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
_register_json_schema(
namespace,
model.__name__,
model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
@ -34,8 +62,10 @@ def get_or_create_model(model_name: str, field_def):
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
"""Register multiple StrEnum with a namespace."""
for model in models:
namespace.schema_model(
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
_register_json_schema(
namespace,
model.__name__,
TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)

View File

@ -12,6 +12,7 @@ from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
@ -301,15 +302,7 @@ class BatchAddNotificationAccountsPayload(BaseModel):
user_email: list[str] = Field(..., description="List of account email addresses")
console_ns.schema_model(
UpsertNotificationPayload.__name__,
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
BatchAddNotificationAccountsPayload.__name__,
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
@console_ns.route("/admin/upsert_notification")

View File

@ -1,4 +1,5 @@
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Literal
@ -8,6 +9,7 @@ from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -23,6 +25,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -58,6 +61,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
class AppListQuery(BaseModel):
@ -67,22 +71,19 @@ class AppListQuery(BaseModel):
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not isinstance(value, list):
raise ValueError("Unsupported tag_ids type.")
items = [str(item).strip() for item in value if item and str(item).strip()]
if not items:
return None
@ -92,6 +93,26 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
normalized: dict[str, str | list[str]] = {}
indexed_tag_ids: list[tuple[int, str]] = []
for key in query_args:
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
if match:
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
continue
value = query_args.get(key)
if value is not None:
normalized[key] = value
if indexed_tag_ids:
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
return normalized
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -460,7 +481,7 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
args_dict = args.model_dump()
# get app list
@ -857,7 +878,8 @@ class AppTraceApi(Resource):
@account_initialization_required
def get(self, app_id):
"""Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session)
return app_trace_config

View File

@ -2,7 +2,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@ -33,6 +33,7 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
register_enum_models(console_ns, ImportStatus)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)

View File

@ -3,6 +3,7 @@ from collections.abc import Sequence
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@ -19,13 +20,12 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
@ -41,16 +41,16 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ModelConfig)
register_enum_models(console_ns, LLMMode)
register_schema_models(
console_ns,
RuleGeneratePayload,
RuleCodeGeneratePayload,
RuleStructuredOutputPayload,
InstructionGeneratePayload,
InstructionTemplatePayload,
ModelConfig,
)
@console_ns.route("/rule-generate")

View File

@ -75,14 +75,15 @@ console_ns.schema_model(
def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
match value:
case FileSegment():
return value.value.model_dump()
case ArrayFileSegment():
return [i.model_dump() for i in value.value]
case SegmentGroup():
return [_convert_values_to_json_serializable_object(i) for i in value.value]
case _:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable):

View File

@ -52,7 +52,7 @@ class RecommendedAppResponse(ResponseModel):
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
category: str | None = None
categories: list[str] = Field(default_factory=list)
position: int | None = None
is_listed: bool | None = None
can_trial: bool | None = None

View File

@ -32,12 +32,7 @@ class TagBindingPayload(BaseModel):
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
@ -75,7 +70,6 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -184,13 +178,13 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
@ -211,54 +205,15 @@ class TagBindingCollectionApi(Resource):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
class TagBindingRemoveApi(Resource):
"""Batch resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
@console_ns.doc("remove_tag_bindings")
@console_ns.doc(description="Remove one or more tag bindings from a target.")
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_binding()
return _remove_tag_bindings()

View File

@ -876,10 +876,10 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@account_initialization_required
def post(self, provider):
current_user, current_tenant_id = current_account_with_tenant()
_, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
tenant_id=current_tenant_id, provider=provider, id=payload.id
)

View File

@ -2,7 +2,7 @@ from typing import Any, Literal, cast
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
class TagUnbindingPayload(BaseModel):
tag_id: str
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
tag_ids: list[str] = Field(default_factory=list)
tag_id: str | None = None
target_id: str
@model_validator(mode="before")
@classmethod
def normalize_legacy_tag_id(cls, data: object) -> object:
if not isinstance(data, dict):
return data
if not data.get("tag_ids") and data.get("tag_id"):
return {**data, "tag_ids": [data["tag_id"]]}
return data
@model_validator(mode="after")
def validate_tag_ids(self) -> "TagUnbindingPayload":
if not self.tag_ids:
raise ValueError("Tag IDs is required.")
return self
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
@service_api_ns.doc("unbind_dataset_tag")
@service_api_ns.doc(description="Unbind a tag from a dataset")
@service_api_ns.doc("unbind_dataset_tags")
@service_api_ns.doc(description="Unbind tags from a dataset")
@service_api_ns.doc(
responses={
204: "Tag unbound successfully",
204: "Tags unbound successfully",
401: "Unauthorized - invalid API token",
403: "Forbidden - insufficient permissions",
}
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
return "", 204

View File

@ -23,6 +23,7 @@ from . import (
feature,
files,
forgot_password,
human_input_file_upload,
human_input_form,
login,
message,
@ -46,6 +47,7 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_file_upload",
"human_input_form",
"login",
"message",

View File

@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
def post(self, app_model: App, end_user: EndUser):
"""Convert audio to text"""
file = request.files["file"]
try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
500: "Internal Server Error",
}
)
def post(self, app_model: App, end_user):
def post(self, app_model: App, end_user: EndUser):
"""Convert text to audio"""
try:
payload = TextToAudioPayload.model_validate(web_ns.payload or {})

View File

@ -0,0 +1,181 @@
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, HttpUrl
import services
from controllers.common import helpers
from controllers.common.errors import (
BlockedFileExtensionError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileResponse, FileWithSignedUrl
from graphon.file import helpers as file_helpers
from libs.exception import BaseHTTPException
from services.file_service import FileService
from services.human_input_file_upload_service import (
HITL_UPLOAD_TOKEN_PREFIX,
HumanInputFileUploadService,
InvalidUploadTokenError,
)
class InvalidUploadTokenBadRequestError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Invalid upload token."
code = 400
class InvalidUploadTokenUnauthorizedError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is required."
code = 401
class InvalidUploadTokenForbiddenError(BaseHTTPException):
error_code = "invalid_upload_token"
description = "Upload token is invalid or expired."
code = 403
class HumanInputRemoteFileUploadPayload(BaseModel):
url: HttpUrl = Field(description="Remote file URL")
register_schema_models(web_ns, HumanInputRemoteFileUploadPayload, FileResponse, FileWithSignedUrl)
def _extract_hitl_upload_token() -> str:
"""Read HITL upload token from Authorization without invoking other bearer auth chains."""
authorization = request.headers.get("Authorization")
if authorization is None:
raise InvalidUploadTokenUnauthorizedError()
parts = authorization.split()
if len(parts) != 2:
raise InvalidUploadTokenUnauthorizedError()
scheme, token = parts
if scheme.lower() != "bearer":
raise InvalidUploadTokenBadRequestError()
if not token:
raise InvalidUploadTokenUnauthorizedError()
if not token.startswith(HITL_UPLOAD_TOKEN_PREFIX):
raise InvalidUploadTokenBadRequestError()
return token
def _validate_context(service: HumanInputFileUploadService, token: str):
try:
return service.validate_upload_token(token)
except InvalidUploadTokenError as exc:
raise InvalidUploadTokenForbiddenError() from exc
def _parse_local_upload_file():
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.filename:
from controllers.common.errors import FilenameNotExistsError
raise FilenameNotExistsError()
return file
@web_ns.route("/form/human_input/files/upload")
class HumanInputFileUploadApi(Resource):
def post(self):
"""Upload one local file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
file = _parse_local_upload_file()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename or "",
content=file.read(),
mimetype=file.mimetype,
user=context.owner,
source=None,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201
@web_ns.route("/form/human_input/files/remote-upload")
class HumanInputRemoteFileUploadApi(Resource):
def post(self):
"""Upload one remote URL file for a HITL human input form."""
token = _extract_hitl_upload_token()
upload_service = HumanInputFileUploadService(db.engine)
context = _validate_context(upload_service, token)
payload = HumanInputRemoteFileUploadPayload.model_validate(request.get_json(silent=True) or {})
url = str(payload.url)
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as exc:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(exc)}")
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=context.owner,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError() from exc
upload_service.record_upload_file(context=context, file_id=upload_file.id)
payload1 = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload1.model_dump(mode="json"), 201

View File

@ -9,11 +9,13 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -21,11 +23,20 @@ from extensions.ext_database import db
from libs.helper import RateLimiter, extract_remote_ip
from models.account import TenantStatus
from models.model import App, Site
from services.human_input_file_upload_service import HumanInputFileUploadService
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
class HumanInputUploadTokenResponse(BaseModel):
upload_token: str
expires_at: int
register_schema_models(web_ns, HumanInputUploadTokenResponse)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -36,6 +47,11 @@ _FORM_ACCESS_RATE_LIMITER = RateLimiter(
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_UPLOAD_TOKEN_RATE_LIMITER = RateLimiter(
prefix="web_form_upload_token_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
@ -78,6 +94,33 @@ def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Re
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@web_ns.route("/form/human_input/<string:form_token>/upload-token")
class HumanInputFormUploadTokenApi(Resource):
"""API for issuing HITL upload tokens for active human input forms."""
def post(self, form_token: str):
"""
Issue an upload token for a human input form.
POST /api/form/human_input/<form_token>/upload-token
"""
ip_address = extract_remote_ip(request)
if _FORM_UPLOAD_TOKEN_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_UPLOAD_TOKEN_RATE_LIMITER.increment_rate_limit(ip_address)
try:
token = HumanInputFileUploadService(db.engine).issue_upload_token(form_token)
except FormNotFoundError:
raise NotFoundError("Form not found")
response = HumanInputUploadTokenResponse(
upload_token=token.upload_token,
expires_at=_to_timestamp(token.expires_at),
)
return response.model_dump(mode="json"), 200
@web_ns.route("/form/human_input/<string:form_token>")
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""

View File

@ -408,17 +408,19 @@ class WorkflowResponseConverter:
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
data = HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
if event.submitted_data is not None:
runtime_type_converter = WorkflowRuntimeTypeConverter()
data.submitted_data = runtime_type_converter.value_to_json_encodable_recursive(event.submitted_data)
return HumanInputFormFilledResponse(task_id=task_id, workflow_run_id=run_id, data=data)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
@ -842,24 +844,24 @@ class WorkflowResponseConverter:
return []
files: list[Mapping[str, Any]] = []
if isinstance(value, FileSegment):
files.append(value.value.to_dict())
elif isinstance(value, ArrayFileSegment):
files.extend([i.to_dict() for i in value.value])
elif isinstance(value, File):
files.append(value.to_dict())
elif isinstance(value, list):
for item in value:
file = cls._get_file_var_from_value(item)
match value:
case FileSegment():
files.append(value.value.to_dict())
case ArrayFileSegment():
files.extend([i.to_dict() for i in value.value])
case File():
files.append(value.to_dict())
case list():
for item in value:
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
case dict():
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
elif isinstance(
value,
dict,
):
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
case _:
pass
return files

View File

@ -432,6 +432,7 @@ class WorkflowBasedAppRunner:
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
submitted_data=event.submitted_data,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):

View File

@ -11,6 +11,7 @@ from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReason
from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.variables.segments import Segment
class QueueEvent(StrEnum):
@ -508,6 +509,10 @@ class QueueHumanInputFormFilledEvent(AppQueueEvent):
action_id: str
action_text: str
# Keep the field name aligned with Graphon so the app-layer bridge does not
# need to translate between two equivalent payload names.
submitted_data: Mapping[str, Segment] | None = None
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""

View File

@ -10,7 +10,7 @@ from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
class AnnotationReplyAccount(BaseModel):
@ -284,8 +284,8 @@ class HumanInputRequiredResponse(StreamResponse):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
@ -307,8 +307,8 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
@ -342,6 +342,8 @@ class HumanInputFormFilledResponse(StreamResponse):
action_id: str
action_text: str
submitted_data: Mapping[str, Any] | None = None
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data

View File

@ -3,9 +3,9 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from graphon.nodes.human_input.entities import FormInput, UserAction
from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig
from models.execution_extra_content import ExecutionContentType
@ -16,9 +16,11 @@ class HumanInputFormDefinition(BaseModel):
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
inputs: Sequence[FormInputConfig] = Field(default_factory=list)
actions: Sequence[UserActionConfig] = Field(default_factory=list)
display_in_ui: bool = False
# `form_token` is `None` if the corresponding form has been submitted.
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@ -29,16 +31,31 @@ class HumanInputFormSubmissionData(BaseModel):
node_id: str
node_title: str
# deprecated: rendered_content is kept only for historical reasons.
rendered_content: str
# The identifier of action user has chosen.
action_id: str
# The button text of the action user has chosen.
action_text: str
# submitted_data records the submitted form data.
# Keys correspond to `output_variable_name` of HumanInput inputs.
# Values are serialized JSON forms of runtime values, including file dictionaries.
#
# For forms submitted before this field was introduced, it is populated from
# the stored submission data.
submitted_data: Mapping[str, JsonValue] | None = None
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
# Both the form_definition and the form_submission_data are present in
# HumanInputContent; for historical records, either field may be missing.
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
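For orientation, a made-up example of the JSON-encodable submitted_data shape described above; the field names and ids are illustrative, not taken from this change:

# Keys are the output_variable_name of each HumanInput input; file values are
# serialized file dictionaries rather than runtime File objects.
example_submitted_data = {
    "feedback_text": "Looks good to me",
    "attachment": {
        "transfer_method": "local_file",
        "upload_file_id": "00000000-0000-0000-0000-000000000000",
        "type": "document",
    },
}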

View File

@ -569,13 +569,13 @@ class OpsTraceManager:
db.session.commit()
@classmethod
def get_app_tracing_config(cls, app_id: str):
def get_app_tracing_config(cls, app_id: str, session: Session):
"""
Get app tracing config
:param app_id: app id
:param session: SQLAlchemy session used to load the App record
:return:
"""
app: App | None = db.session.get(App, app_id)
app: App | None = session.get(App, app_id)
if not app:
raise ValueError("App not found")
if not app.tracing:
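A hedged sketch of a call site after this signature change, assuming the import path for OpsTraceManager and that session_factory.create_session() yields the SQLAlchemy Session the method now expects:

from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager  # assumed module path

with session_factory.create_session() as session:
    # The caller now opens and owns the session instead of relying on db.session.
    tracing_config = OpsTraceManager.get_app_tracing_config(app_id="your-app-id", session=session)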

View File

@ -53,24 +53,27 @@ class PromptMessageUtil:
files = []
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
elif isinstance(content, AudioPromptMessageContent):
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
match content:
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
files.append(
{
"type": "image",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"detail": content.detail.value,
}
)
case AudioPromptMessageContent():
files.append(
{
"type": "audio",
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
"format": content.format,
}
)
case _:
continue
else:
text = cast(str, prompt_message.content)

View File

@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
@ -445,7 +445,7 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
provider_name_to_provider_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
@ -462,7 +462,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_records_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
@ -478,7 +478,7 @@ class ProviderManager:
:return:
"""
provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
@ -496,7 +496,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_settings_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
@ -514,7 +514,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
@ -544,7 +544,7 @@ class ProviderManager:
return {}
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
@ -578,7 +578,7 @@ class ProviderManager:
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderCredential)
.where(
@ -608,7 +608,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
stmt = (
select(ProviderModelCredential)
.where(

View File

@ -1078,6 +1078,13 @@ class ToolManager:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
if variable_pool:
config = tool_configurations.get(parameter.name, {})
selector_value = cls._extract_runtime_selector_value(parameter, config)
if selector_value is not None:
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
runtime_parameters[parameter.name] = selector_value
continue
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
@ -1105,5 +1112,39 @@ class ToolManager:
runtime_parameters[parameter.name] = value
return runtime_parameters
@classmethod
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
if parameter.type not in {
ToolParameter.ToolParameterType.MODEL_SELECTOR,
ToolParameter.ToolParameterType.APP_SELECTOR,
}:
return None
if not isinstance(config, dict):
return None
input_value = config.get("value")
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
if cls._is_selector_value(parameter, config):
selector_value = dict(config)
selector_value.pop("type", None)
selector_value.pop("value", None)
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
return None
@classmethod
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
)
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
return isinstance(value.get("app_id"), str)
return False
ToolManager.load_hardcoded_providers_cache()
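As a rough illustration of the check in _is_selector_value, the two accepted shapes look like this; the concrete values are made up:

# Model selector: provider, model and model_type must all be strings.
model_selector = {"provider": "openai", "model": "gpt-4o", "model_type": "llm"}
# App selector: only app_id needs to be a string.
app_selector = {"app_id": "00000000-0000-0000-0000-000000000000"}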

View File

@ -23,36 +23,37 @@ _TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P<tool_file_id>[^/?#
def safe_json_value(v):
if isinstance(v, datetime):
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
elif isinstance(v, date):
return v.isoformat()
elif isinstance(v, UUID):
return str(v)
elif isinstance(v, Decimal):
return float(v)
elif isinstance(v, bytes):
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
elif isinstance(v, np.integer):
return int(v)
elif isinstance(v, np.floating):
return float(v)
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
return safe_json_dict(v)
elif isinstance(v, list | tuple | set):
return [safe_json_value(i) for i in v]
else:
return v
match v:
case datetime():
tz_name = "UTC"
if isinstance(current_user, Account) and current_user.timezone is not None:
tz_name = current_user.timezone
return v.astimezone(pytz.timezone(tz_name)).isoformat()
case date():
return v.isoformat()
case UUID():
return str(v)
case Decimal():
return float(v)
case bytes():
try:
return v.decode("utf-8")
except UnicodeDecodeError:
return v.hex()
case memoryview():
return v.tobytes().hex()
case np.integer():
return int(v)
case np.floating():
return float(v)
case np.ndarray():
return v.tolist()
case dict():
return safe_json_dict(v)
case list() | tuple() | set():
return [safe_json_value(i) for i in v]
case _:
return v
def safe_json_dict(d: dict[str, Any]):

View File

@ -272,6 +272,14 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
normalized_tool_configurations[name] = value
continue
selector_value = _extract_selector_configuration(value)
if selector_value is not None:
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
@ -310,6 +318,28 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
return None
def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
input_value = value.get("value")
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
return dict(input_value)
if _is_selector_configuration(value):
selector_value = dict(value)
selector_value.pop("type", None)
selector_value.pop("value", None)
return selector_value
return None
def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
) or isinstance(value.get("app_id"), str)
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)

View File

@ -42,7 +42,12 @@ from graphon.model_runtime.entities.llm_entities import (
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.entities import (
FileInputConfig,
FileListInputConfig,
FormInputConfig,
HumanInputNodeData,
)
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
PromptMessageSerializerProtocol,
@ -78,7 +83,6 @@ from .system_variables import SystemVariableKey, get_system_text
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
from graphon.file import File
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.tool.entities import ToolNodeData
@ -501,11 +505,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@staticmethod
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
tool_configurations = dict(node_data.tool_configurations)
tool_configurations.update(
{name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()}
)
return _WorkflowToolRuntimeSpec(
provider_type=CoreToolProviderType(node_data.provider_type.value),
provider_id=node_data.provider_id,
tool_name=node_data.tool_name,
tool_configurations=dict(node_data.tool_configurations),
tool_configurations=tool_configurations,
credential_id=node_data.credential_id,
)
@ -625,6 +633,7 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
self._run_context = resolve_dify_run_context(run_context)
self._workflow_execution_id_getter = workflow_execution_id_getter
self._form_repository = form_repository
self._file_reference_factory = DifyFileReferenceFactory(self._run_context)
def _invoke_source(self) -> str:
invoke_from = self._run_context.invoke_from
@ -678,6 +687,23 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
repo = self.build_form_repository()
return repo.get_form(node_id)
def restore_submitted_data(
self,
*,
node_data: HumanInputNodeData,
submitted_data: Mapping[str, Any],
) -> Mapping[str, Any]:
restored_data: dict[str, Any] = dict(submitted_data)
for input_config in node_data.inputs:
output_variable_name = input_config.output_variable_name
if output_variable_name not in submitted_data:
continue
restored_data[output_variable_name] = self._restore_submitted_value(
input_config=input_config,
value=submitted_data[output_variable_name],
)
return restored_data
def create_form(
self,
*,
@ -698,6 +724,55 @@ class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
)
return repo.create_form(params)
def _restore_submitted_value(
self,
*,
input_config: FormInputConfig,
value: Any,
) -> Any:
if isinstance(input_config, FileInputConfig):
return self._restore_submitted_file_value(
output_variable_name=input_config.output_variable_name,
value=value,
)
if isinstance(input_config, FileListInputConfig):
return self._restore_submitted_file_list_value(
output_variable_name=input_config.output_variable_name,
value=value,
)
return value
def _restore_submitted_file_value(
self,
*,
output_variable_name: str,
value: Any,
) -> Any:
if not isinstance(value, Mapping):
msg = (
"HumanInput file submission must be persisted as a mapping, "
f"output_variable_name={output_variable_name}"
)
raise ValueError(msg)
return self._file_reference_factory.build_from_mapping(mapping=value)
def _restore_submitted_file_list_value(
self,
*,
output_variable_name: str,
value: Any,
) -> list[Any]:
if not isinstance(value, list):
msg = (
"HumanInput file-list submission must be persisted as a list, "
f"output_variable_name={output_variable_name}"
)
raise ValueError(msg)
if any(not isinstance(item, Mapping) for item in value):
msg = f"HumanInput file-list submission must contain mappings, output_variable_name={output_variable_name}"
raise ValueError(msg)
return [self._file_reference_factory.build_from_mapping(mapping=item) for item in value]
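For reference, the persisted shapes the restore helpers above accept, with placeholder values: a single file input is stored as a mapping and a file-list input as a list of mappings.

persisted_file_value = {
    "transfer_method": "tool_file",
    "tool_file_id": "00000000-0000-0000-0000-000000000000",
    "type": "image",
}
persisted_file_list_value = [persisted_file_value]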
def build_dify_llm_file_saver(
*,

View File

@ -0,0 +1,95 @@
"""Generate FastOpenAPI OpenAPI 3.0 specs without booting the full backend."""
from __future__ import annotations
import argparse
import json
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
API_ROOT = Path(__file__).resolve().parents[1]
if str(API_ROOT) not in sys.path:
sys.path.insert(0, str(API_ROOT))
from dev.generate_swagger_specs import apply_runtime_defaults, drop_null_values, sort_openapi_arrays
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class FastOpenApiSpecTarget:
route: str
filename: str
FASTOPENAPI_SPEC_TARGETS: tuple[FastOpenApiSpecTarget, ...] = (
FastOpenApiSpecTarget(route="/fastopenapi/openapi.json", filename="fastopenapi-console-openapi.json"),
)
def create_fastopenapi_spec_app():
"""Build a minimal Flask app that only mounts FastOpenAPI docs routes."""
apply_runtime_defaults()
from app_factory import create_flask_app_with_configs
from extensions import ext_fastopenapi
app = create_flask_app_with_configs()
ext_fastopenapi.init_app(app)
return app
def generate_fastopenapi_specs(output_dir: Path) -> list[Path]:
"""Write FastOpenAPI specs to `output_dir` and return the written paths."""
output_dir.mkdir(parents=True, exist_ok=True)
app = create_fastopenapi_spec_app()
client = app.test_client()
written_paths: list[Path] = []
for target in FASTOPENAPI_SPEC_TARGETS:
response = client.get(target.route)
if response.status_code != 200:
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
payload = response.get_json()
if not isinstance(payload, dict):
raise RuntimeError(f"unexpected response payload for {target.route}")
payload = drop_null_values(payload)
payload = sort_openapi_arrays(payload)
output_path = output_dir / target.filename
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
written_paths.append(output_path)
return written_paths
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"-o",
"--output-dir",
type=Path,
default=Path("openapi"),
help="Directory where the OpenAPI JSON files will be written.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
written_paths = generate_fastopenapi_specs(args.output_dir)
for path in written_paths:
logger.debug(path)
return 0
if __name__ == "__main__":
raise SystemExit(main())
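A small usage sketch of the generator above, run from the api directory; the output filename follows FASTOPENAPI_SPEC_TARGETS:

from pathlib import Path

from dev.generate_fastopenapi_specs import generate_fastopenapi_specs

written = generate_fastopenapi_specs(Path("openapi"))
# Expected to include openapi/fastopenapi-console-openapi.json
print([str(path) for path in written])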

View File

@ -0,0 +1,161 @@
"""Generate OpenAPI JSON specs and split Markdown API docs.
The Markdown step uses `swagger-markdown`, the same converter family as the
Swagger Markdown UI, so CI and local regeneration catch converter-incompatible
OpenAPI output early.
"""
from __future__ import annotations
import argparse
import logging
import subprocess
import sys
import tempfile
from pathlib import Path
API_ROOT = Path(__file__).resolve().parents[1]
if str(API_ROOT) not in sys.path:
sys.path.insert(0, str(API_ROOT))
from dev.generate_fastopenapi_specs import FASTOPENAPI_SPEC_TARGETS, generate_fastopenapi_specs
from dev.generate_swagger_specs import SPEC_TARGETS, generate_specs
logger = logging.getLogger(__name__)
SWAGGER_MARKDOWN_PACKAGE = "swagger-markdown@3.0.0"
CONSOLE_SWAGGER_FILENAME = "console-swagger.json"
STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md"
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
subprocess.run(
[
"npx",
"--yes",
SWAGGER_MARKDOWN_PACKAGE,
"-i",
str(spec_path),
"-o",
str(markdown_path),
],
check=True,
)
def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str:
"""Nest generated Markdown under another Markdown section."""
heading_prefix = "#" * levels
lines = []
for line in markdown.splitlines():
if line.startswith("#"):
lines.append(f"{heading_prefix}{line}")
else:
lines.append(line)
return "\n".join(lines).strip()
def _append_fastopenapi_markdown(console_markdown_path: Path, fastopenapi_markdown_path: Path) -> None:
"""Append FastOpenAPI console docs to the existing console API Markdown."""
console_markdown = console_markdown_path.read_text(encoding="utf-8").rstrip()
fastopenapi_markdown = _demote_markdown_headings(
fastopenapi_markdown_path.read_text(encoding="utf-8"),
levels=2,
)
console_markdown_path.write_text(
"\n\n".join(
[
console_markdown,
"## FastOpenAPI Preview (OpenAPI 3.0)",
fastopenapi_markdown,
]
)
+ "\n",
encoding="utf-8",
)
def generate_markdown_docs(
swagger_dir: Path,
markdown_dir: Path,
*,
keep_swagger_json: bool = False,
) -> list[Path]:
"""Generate intermediate specs, convert them to split Markdown API docs, and return Markdown paths."""
swagger_paths = generate_specs(swagger_dir)
fastopenapi_paths = generate_fastopenapi_specs(swagger_dir)
spec_paths = [*swagger_paths, *fastopenapi_paths]
swagger_paths_by_name = {path.name: path for path in swagger_paths}
fastopenapi_paths_by_name = {path.name: path for path in fastopenapi_paths}
markdown_dir.mkdir(parents=True, exist_ok=True)
written_paths: list[Path] = []
try:
with tempfile.TemporaryDirectory(prefix="dify-api-docs-") as temp_dir:
temp_markdown_dir = Path(temp_dir)
for target in SPEC_TARGETS:
swagger_path = swagger_paths_by_name[target.filename]
markdown_path = markdown_dir / f"{swagger_path.stem}.md"
_convert_spec_to_markdown(swagger_path, markdown_path)
written_paths.append(markdown_path)
for target in FASTOPENAPI_SPEC_TARGETS: # type: ignore
fastopenapi_path = fastopenapi_paths_by_name[target.filename]
markdown_path = temp_markdown_dir / f"{fastopenapi_path.stem}.md"
_convert_spec_to_markdown(fastopenapi_path, markdown_path)
console_markdown_path = markdown_dir / f"{Path(CONSOLE_SWAGGER_FILENAME).stem}.md"
_append_fastopenapi_markdown(console_markdown_path, markdown_path)
(markdown_dir / STALE_COMBINED_MARKDOWN_FILENAME).unlink(missing_ok=True)
finally:
if not keep_swagger_json:
for path in spec_paths:
path.unlink(missing_ok=True)
return written_paths
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--swagger-dir",
type=Path,
default=Path("openapi"),
help="Directory where intermediate JSON spec files will be written.",
)
parser.add_argument(
"--markdown-dir",
type=Path,
default=Path("openapi/markdown"),
help="Directory where split Markdown API docs will be written.",
)
parser.add_argument(
"--keep-swagger-json",
action="store_true",
help="Keep intermediate JSON spec files after Markdown generation.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
written_paths = generate_markdown_docs(
args.swagger_dir,
args.markdown_dir,
keep_swagger_json=args.keep_swagger_json,
)
for path in written_paths:
logger.debug(path)
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -9,12 +9,15 @@ which is unnecessary when the goal is only to serialize the Flask-RESTX
from __future__ import annotations
import argparse
import hashlib
import json
import logging
import os
import sys
from collections.abc import MutableMapping
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, TypeGuard
from flask import Flask
from flask_restx.swagger import Swagger
@ -30,19 +33,110 @@ if str(API_ROOT) not in sys.path:
class SpecTarget:
route: str
filename: str
namespace: str
class RestxApi(Protocol):
models: MutableMapping[str, object]
def model(self, name: str, model: dict[object, object]) -> object: ...
SPEC_TARGETS: tuple[SpecTarget, ...] = (
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
)
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
def _apply_runtime_defaults() -> None:
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
"""Return whether a nested field map is an anonymous inline mapping."""
from flask_restx.model import Model, OrderedModel
return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel))
def _jsonable_schema_value(value: object) -> object:
"""Return a deterministic JSON-serializable representation for schema fingerprints."""
if value is None or isinstance(value, str | int | float | bool):
return value
if isinstance(value, list | tuple):
return [_jsonable_schema_value(item) for item in value]
if isinstance(value, dict):
return {str(key): _jsonable_schema_value(item) for key, item in value.items()}
value_type = type(value)
return f"<{value_type.__module__}.{value_type.__qualname__}>"
def _field_signature(field: object) -> object:
"""Build a stable signature for a Flask-RESTX field object."""
from flask_restx import fields
from flask_restx.model import instance
field_instance = instance(field)
signature: dict[str, object] = {
"class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}"
}
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested):
signature["nested"] = _inline_model_signature(nested)
else:
signature["nested"] = getattr(
nested,
"name",
f"<{type(nested).__module__}.{type(nested).__qualname__}>",
)
elif hasattr(field_instance, "container"):
signature["container"] = _field_signature(field_instance.container)
else:
schema = getattr(field_instance, "__schema__", None)
if isinstance(schema, dict):
signature["schema"] = _jsonable_schema_value(schema)
for attr_name in (
"attribute",
"default",
"description",
"example",
"max",
"min",
"nullable",
"readonly",
"required",
"title",
):
if hasattr(field_instance, attr_name):
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
return signature
def _inline_model_signature(nested_fields: dict[object, object]) -> object:
"""Build a stable signature for an anonymous inline model."""
return [
(str(field_name), _field_signature(field))
for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0]))
]
def _inline_model_name(nested_fields: dict[object, object]) -> str:
"""Return a stable Swagger model name for an anonymous inline field map."""
signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":"))
digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12]
return f"_AnonymousInlineModel_{digest}"
def apply_runtime_defaults() -> None:
"""Force the small config surface required for Swagger generation."""
os.environ.setdefault("SECRET_KEY", "spec-export")
@ -74,25 +168,26 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
anonymous_models = getattr(self, "_anonymous_inline_models", None)
if anonymous_models is None:
anonymous_models = {}
self._anonymous_inline_models = anonymous_models
self.__dict__["_anonymous_inline_models"] = anonymous_models
anonymous_name = anonymous_models.get(id(nested_fields))
if anonymous_name is None:
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
anonymous_name = _inline_model_name(nested_fields)
anonymous_models[id(nested_fields)] = anonymous_name
self.api.model(anonymous_name, nested_fields)
if anonymous_name not in self.api.models:
self.api.model(anonymous_name, nested_fields)
return self.api.models[anonymous_name]
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
if isinstance(model, dict):
if _is_inline_field_map(model):
model = get_or_create_inline_model(self, model)
return _ORIGINAL_REGISTER_MODEL(self, model)
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
nested = getattr(field, "nested", None)
if isinstance(nested, dict):
if _is_inline_field_map(nested):
field.model = get_or_create_inline_model(self, nested) # type: ignore
_ORIGINAL_REGISTER_FIELD(self, field)
@ -105,22 +200,169 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
def create_spec_app() -> Flask:
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
_apply_runtime_defaults()
apply_runtime_defaults()
_patch_swagger_for_inline_nested_dicts()
app = Flask(__name__)
from controllers.console import bp as console_bp
from controllers.console import console_ns
from controllers.service_api import bp as service_api_bp
from controllers.service_api import service_api_ns
from controllers.web import bp as web_bp
from controllers.web import web_ns
app.register_blueprint(console_bp)
app.register_blueprint(web_bp)
app.register_blueprint(service_api_bp)
for namespace in (console_ns, web_ns, service_api_ns):
for api in namespace.apis:
_materialize_inline_model_definitions(api)
return app
def _registered_models(namespace: str) -> dict[str, object]:
"""Return the Flask-RESTX models registered for a Swagger namespace."""
if namespace == "console":
from controllers.console import console_ns
models = dict(console_ns.models)
for api in console_ns.apis:
models.update(api.models)
return models
if namespace == "web":
from controllers.web import web_ns
models = dict(web_ns.models)
for api in web_ns.apis:
models.update(api.models)
return models
if namespace == "service":
from controllers.service_api import service_api_ns
models = dict(service_api_ns.models)
for api in service_api_ns.apis:
models.update(api.models)
return models
raise ValueError(f"unknown Swagger namespace: {namespace}")
def _materialize_inline_model_definitions(api: RestxApi) -> None:
"""Convert inline `fields.Nested({...})` maps into named API models."""
from flask_restx import fields
from flask_restx.model import Model, OrderedModel, instance
inline_models: dict[int, dict[object, object]] = {}
inline_model_names: dict[int, str] = {}
def collect_field(field: object) -> None:
field_instance = instance(field)
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested) and id(nested) not in inline_models:
inline_models[id(nested)] = nested
for nested_field in nested.values():
collect_field(nested_field)
container = getattr(field_instance, "container", None)
if container is not None:
collect_field(container)
for model in list(api.models.values()):
if isinstance(model, (Model, OrderedModel)):
for field in model.values():
collect_field(field)
for nested_fields in sorted(inline_models.values(), key=_inline_model_name):
anonymous_name = _inline_model_name(nested_fields)
inline_model_names[id(nested_fields)] = anonymous_name
if anonymous_name not in api.models:
api.model(anonymous_name, nested_fields)
def model_name_for(nested_fields: dict[object, object]) -> str:
anonymous_name = inline_model_names.get(id(nested_fields))
if anonymous_name is None:
anonymous_name = _inline_model_name(nested_fields)
inline_model_names[id(nested_fields)] = anonymous_name
if anonymous_name not in api.models:
api.model(anonymous_name, nested_fields)
return anonymous_name
def materialize_field(field: object) -> None:
field_instance = instance(field)
if isinstance(field_instance, fields.Nested):
nested = getattr(field_instance, "nested", None)
if _is_inline_field_map(nested):
field_instance.model = api.models[model_name_for(nested)] # type: ignore[attr-defined]
container = getattr(field_instance, "container", None)
if container is not None:
materialize_field(container)
index = 0
while index < len(api.models):
model = list(api.models.values())[index]
index += 1
if isinstance(model, (Model, OrderedModel)):
for field in model.values():
materialize_field(field)
def drop_null_values(value: object) -> object:
"""Remove JSON null values that make the Markdown converter crash."""
if isinstance(value, dict):
return {key: drop_null_values(item) for key, item in value.items() if item is not None}
if isinstance(value, list):
return [drop_null_values(item) for item in value]
return value
def sort_openapi_arrays(value: object, *, parent_key: str | None = None) -> object:
"""Sort order-insensitive Swagger arrays so generated Markdown is stable."""
if isinstance(value, dict):
return {key: sort_openapi_arrays(item, parent_key=key) for key, item in value.items()}
if not isinstance(value, list):
return value
sorted_items = [sort_openapi_arrays(item, parent_key=parent_key) for item in value]
if parent_key == "parameters":
return sorted(
sorted_items,
key=lambda item: (
item.get("in", "") if isinstance(item, dict) else "",
item.get("name", "") if isinstance(item, dict) else "",
json.dumps(item, sort_keys=True, default=str),
),
)
if parent_key in {"enum", "required", "schemes", "tags"}:
string_items = [item for item in sorted_items if isinstance(item, str)]
if len(string_items) == len(sorted_items):
return sorted(string_items)
return sorted_items
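A hedged example of the two cleanup helpers applied together; the payload fragment is made up.

from dev.generate_swagger_specs import drop_null_values, sort_openapi_arrays

payload = {
    "tags": ["web", "console"],
    "parameters": [{"in": "query", "name": "limit"}, {"in": "path", "name": "app_id"}],
    "deprecated": None,
}
cleaned = sort_openapi_arrays(drop_null_values(payload))
assert cleaned["tags"] == ["console", "web"]
assert [param["in"] for param in cleaned["parameters"]] == ["path", "query"]
assert "deprecated" not in cleaned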
def _merge_registered_definitions(payload: dict[str, object], namespace: str) -> dict[str, object]:
"""Include registered but route-indirect models in the exported Swagger definitions."""
definitions = payload.setdefault("definitions", {})
if not isinstance(definitions, dict):
raise RuntimeError("unexpected Swagger definitions payload")
for name, model in _registered_models(namespace).items():
schema = getattr(model, "__schema__", None)
if isinstance(schema, dict):
definitions.setdefault(name, schema)
return payload
def generate_specs(output_dir: Path) -> list[Path]:
"""Write all Swagger specs to `output_dir` and return the written paths."""
@ -138,6 +380,9 @@ def generate_specs(output_dir: Path) -> list[Path]:
payload = response.get_json()
if not isinstance(payload, dict):
raise RuntimeError(f"unexpected response payload for {target.route}")
payload = _merge_registered_definitions(payload, target.namespace)
payload = drop_null_values(payload)
payload = sort_openapi_arrays(payload)
output_path = output_dir / target.filename
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")

View File

@ -1,7 +1,9 @@
from flask import Flask
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
def init_app(app):
def init_app(app: Flask):
with app.app_context():
configure_session_factory(db.engine)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import mimetypes
import uuid
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Any, Literal, NotRequired, TypedDict, assert_never, cast
from sqlalchemy import select
@ -19,10 +19,58 @@ from .common import resolve_mapping_file_id
from .remote import get_remote_file_info
from .validation import is_file_valid_with_config
type FileTypeValue = FileType | Literal["image", "document", "audio", "video", "custom"]
type _LocalFileTransferMethod = Literal["local_file", FileTransferMethod.LOCAL_FILE]
type _RemoteUrlTransferMethod = Literal["remote_url", FileTransferMethod.REMOTE_URL]
type _ToolFileTransferMethod = Literal["tool_file", FileTransferMethod.TOOL_FILE]
type _DatasourceFileTransferMethod = Literal["datasource_file", FileTransferMethod.DATASOURCE_FILE]
class LocalFileMapping(TypedDict):
transfer_method: _LocalFileTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
upload_file_id: NotRequired[str | None] # File id lookup priority 1.
reference: NotRequired[str | None] # File id lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # File id lookup priority 3; legacy persisted field.
class RemoteUrlMapping(TypedDict):
transfer_method: _RemoteUrlTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
upload_file_id: NotRequired[str | None] # Persisted UploadFile lookup priority 1.
reference: NotRequired[str | None] # Persisted UploadFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # Persisted UploadFile lookup priority 3; legacy persisted field.
url: NotRequired[str | None] # External URL lookup priority 1 when no UploadFile id is resolved.
remote_url: NotRequired[str | None] # External URL lookup priority 2 when no UploadFile id is resolved.
class ToolFileMapping(TypedDict):
transfer_method: _ToolFileTransferMethod
id: NotRequired[str | None] # Read as the graph-layer File.file_id.
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
tool_file_id: NotRequired[str | None] # ToolFile lookup priority 1.
reference: NotRequired[str | None] # ToolFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # ToolFile lookup priority 3; legacy persisted field.
class DatasourceFileMapping(TypedDict):
transfer_method: _DatasourceFileTransferMethod
type: NotRequired[FileTypeValue | None] # Read for type override and upload config validation.
datasource_file_id: NotRequired[str | None] # UploadFile lookup priority 1 for datasource-backed files.
reference: NotRequired[str | None] # UploadFile lookup priority 2; may be an opaque file reference.
related_id: NotRequired[str | None] # UploadFile lookup priority 3; legacy persisted field.
type FileMapping = LocalFileMapping | RemoteUrlMapping | ToolFileMapping | DatasourceFileMapping
type FileMappingInput = FileMapping | Mapping[str, Any]
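Illustrative mappings that satisfy the TypedDicts above; the ids and URL are placeholders.

local_file_mapping: LocalFileMapping = {
    "transfer_method": "local_file",
    "upload_file_id": "00000000-0000-0000-0000-000000000000",
    "type": "document",
}
remote_url_mapping: RemoteUrlMapping = {
    "transfer_method": "remote_url",
    "url": "https://example.com/report.pdf",
}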
def build_from_mapping(
*,
mapping: Mapping[str, Any],
mapping: FileMappingInput,
tenant_id: str,
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
@ -32,18 +80,45 @@ def build_from_mapping(
if not transfer_method_value:
raise ValueError("transfer_method is required in file mapping")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
build_func = _get_build_function(transfer_method)
file = build_func(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
transfer_method = FileTransferMethod.value_of(str(transfer_method_value))
match transfer_method:
case FileTransferMethod.LOCAL_FILE:
file = _build_from_local_file(
mapping=cast(LocalFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.REMOTE_URL:
file = _build_from_remote_url(
mapping=cast(RemoteUrlMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.TOOL_FILE:
file = _build_from_tool_file(
mapping=cast(ToolFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case FileTransferMethod.DATASOURCE_FILE:
file = _build_from_datasource_file(
mapping=cast(DatasourceFileMapping, mapping),
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
access_controller=access_controller,
)
case _:
assert_never(transfer_method)
if config and not is_file_valid_with_config(
input_file_type=mapping.get("type", FileType.CUSTOM),
input_file_type=mapping.get("type") or FileType.CUSTOM,
file_extension=file.extension or "",
file_transfer_method=file.transfer_method,
config=config,
@ -87,19 +162,6 @@ def build_from_mappings(
return files
def _get_build_function(transfer_method: FileTransferMethod):
build_functions = {
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file,
}
build_func = build_functions.get(transfer_method)
if build_func is None:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
return build_func
def _resolve_file_type(
*,
detected_file_type: FileType,
@ -116,7 +178,7 @@ def _resolve_file_type(
def _build_from_local_file(
*,
mapping: Mapping[str, Any],
mapping: LocalFileMapping,
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -163,7 +225,7 @@ def _build_from_local_file(
def _build_from_remote_url(
*,
mapping: Mapping[str, Any],
mapping: RemoteUrlMapping,
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -235,7 +297,7 @@ def _build_from_remote_url(
def _build_from_tool_file(
*,
mapping: Mapping[str, Any],
mapping: ToolFileMapping,
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -278,7 +340,7 @@ def _build_from_tool_file(
def _build_from_datasource_file(
*,
mapping: Mapping[str, Any],
mapping: DatasourceFileMapping,
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
@ -298,7 +360,7 @@ def _build_from_datasource_file(
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
file_type = _resolve_file_type(
detected_file_type=detected_file_type,
specified_type=mapping.get("type"),

View File

@ -1,9 +0,0 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)

View File

@ -0,0 +1,26 @@
"""add recommended app categories
Revision ID: a4f2d8c9b731
Revises: 227822d22895
Create Date: 2026-04-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a4f2d8c9b731"
down_revision = "227822d22895"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
batch_op.add_column(sa.Column("categories", sa.JSON(), nullable=True))
def downgrade():
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
batch_op.drop_column("categories")

View File

@ -0,0 +1,64 @@
"""Add human input upload token and file association tables
Revision ID: 8d4c2a1b9f03
Revises: 227822d22895
Create Date: 2026-05-06 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models
# revision identifiers, used by Alembic.
revision = "8d4c2a1b9f03"
down_revision = "227822d22895"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"human_input_form_upload_tokens",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("recipient_id", models.types.StringUUID(), nullable=False),
sa.Column("token", sa.String(length=255), nullable=False),
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_tokens_pkey"),
sa.UniqueConstraint("token", name="human_input_form_upload_tokens_token_key"),
)
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
batch_op.create_index("human_input_form_upload_tokens_form_id_idx", ["form_id"], unique=False)
op.create_table(
"human_input_form_upload_files",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("form_id", models.types.StringUUID(), nullable=False),
sa.Column("upload_file_id", models.types.StringUUID(), nullable=False),
sa.Column("upload_token_id", models.types.StringUUID(), nullable=False),
sa.PrimaryKeyConstraint("id", name="human_input_form_upload_files_pkey"),
sa.UniqueConstraint("upload_file_id", name="human_input_form_upload_files_upload_file_id_key"),
)
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
batch_op.create_index("human_input_form_upload_files_form_id_idx", ["form_id"], unique=False)
batch_op.create_index("human_input_form_upload_files_upload_token_id_idx", ["upload_token_id"], unique=False)
def downgrade():
with op.batch_alter_table("human_input_form_upload_files", schema=None) as batch_op:
batch_op.drop_index("human_input_form_upload_files_upload_token_id_idx")
batch_op.drop_index("human_input_form_upload_files_form_id_idx")
op.drop_table("human_input_form_upload_files")
with op.batch_alter_table("human_input_form_upload_tokens", schema=None) as batch_op:
batch_op.drop_index("human_input_form_upload_tokens_form_id_idx")
op.drop_table("human_input_form_upload_tokens")

View File

@ -46,7 +46,7 @@ from .evaluation import (
EvaluationTargetType,
)
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
from .human_input import HumanInputForm
from .human_input import HumanInputForm, HumanInputFormUploadFile, HumanInputFormUploadToken
from .model import (
AccountTrialAppRecord,
ApiRequest,
@ -182,6 +182,8 @@ __all__ = [
"ExternalKnowledgeBindings",
"HumanInputContent",
"HumanInputForm",
"HumanInputFormUploadFile",
"HumanInputFormUploadToken",
"IconType",
"InstalledApp",
"InvitationCode",

View File

@ -251,3 +251,55 @@ class HumanInputFormRecipient(DefaultFieldsMixin, Base):
access_token=_generate_token(),
)
return recipient_model
class HumanInputFormUploadToken(DefaultFieldsMixin, Base):
"""Upload authorization token bound to one human input form recipient.
HITL upload tokens are intentionally separate from app/service bearer tokens.
The token is stored as an opaque random value so upload endpoints can perform
a direct lookup without entering the normal Web App authentication chain.
Upload ownership is resolved from the form's workflow run initiator instead
of being persisted on the token row itself.
"""
__tablename__ = "human_input_form_upload_tokens"
__table_args__ = (
sa.UniqueConstraint("token", name="human_input_form_upload_tokens_token_key"),
sa.Index("human_input_form_upload_tokens_form_id_idx", "form_id"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
form_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
recipient_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
token: Mapped[str] = mapped_column(sa.String(255), nullable=False)
form: Mapped[HumanInputForm] = relationship(
"HumanInputForm",
uselist=False,
foreign_keys=[form_id],
primaryjoin="foreign(HumanInputFormUploadToken.form_id) == HumanInputForm.id",
lazy="raise",
)
class HumanInputFormUploadFile(DefaultFieldsMixin, Base):
"""Association between a human input form and a file uploaded through its token.
Ownership remains on ``UploadFile`` itself; this table only records the
durable form/token/file linkage needed by Human Input flows.
"""
__tablename__ = "human_input_form_upload_files"
__table_args__ = (
sa.UniqueConstraint("upload_file_id", name="human_input_form_upload_files_upload_file_id_key"),
sa.Index("human_input_form_upload_files_form_id_idx", "form_id"),
sa.Index("human_input_form_upload_files_upload_token_id_idx", "upload_token_id"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
form_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
upload_token_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
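A minimal sketch, not part of this change, of the direct token lookup the docstring above describes, assuming a SQLAlchemy session is in scope:

from sqlalchemy import select


def find_upload_token(session, token_value: str) -> HumanInputFormUploadToken | None:
    # Token values are unique, so at most one row matches.
    return session.scalar(
        select(HumanInputFormUploadToken).where(HumanInputFormUploadToken.token == token_value)
    )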

View File

@ -878,6 +878,7 @@ class RecommendedApp(TypeBase):
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
category: Mapped[str] = mapped_column(String(255), nullable=False)
categories: Mapped[list[str] | None] = mapped_column(sa.JSON, nullable=True, default=None)
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)

View File

@ -9,11 +9,11 @@ import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy.orm import Mapped, mapped_column
from core.db.session_factory import session_factory
from graphon.model_runtime.entities.model_entities import ModelType
from libs.uuid_utils import uuidv7
from .base import TypeBase
from .engine import db
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
from .types import EnumText, LongText, StringUUID
@ -82,7 +82,8 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
with session_factory.create_session() as session:
return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
@property
def credential_name(self):
@ -145,9 +146,10 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
with session_factory.create_session() as session:
return session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
)
@property
def credential_name(self):

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -3,6 +3,7 @@ from collections.abc import Mapping
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
@ -31,7 +32,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from models import EndUser
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
def test_get_user_id_from_message_data_no_end_user(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = None
@ -39,7 +40,7 @@ def test_get_user_id_from_message_data_no_end_user(monkeypatch):
assert get_user_id_from_message_data(message_data) == "account_id"
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
def test_get_user_id_from_message_data_with_end_user(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = "end_user_id"
@ -57,7 +58,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
assert get_user_id_from_message_data(message_data) == "session_id"
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.from_account_id = "account_id"
message_data.from_end_user_id = "end_user_id"
@ -111,7 +112,7 @@ def test_get_workflow_node_status():
assert status.status_code == StatusCode.UNSET
def test_create_links_from_trace_id(monkeypatch):
def test_create_links_from_trace_id(monkeypatch: pytest.MonkeyPatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
import dify_trace_aliyun.data_exporter.traceclient

View File

@ -40,7 +40,7 @@ def langfuse_config():
@pytest.fixture
def trace_instance(langfuse_config, monkeypatch):
def trace_instance(langfuse_config, monkeypatch: pytest.MonkeyPatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
@ -49,7 +49,7 @@ def trace_instance(langfuse_config, monkeypatch):
return instance
def test_init(langfuse_config, monkeypatch):
def test_init(langfuse_config, monkeypatch: pytest.MonkeyPatch):
mock_langfuse = MagicMock()
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -64,7 +64,7 @@ def test_init(langfuse_config, monkeypatch):
assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -114,7 +114,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Setup trace info
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
@ -218,7 +218,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
assert other_span.level == LevelEnum.ERROR
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
tenant_id="tenant-1",
@ -259,7 +259,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",
tenant_id="tenant-1",
@ -287,7 +287,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -331,7 +331,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
assert gen_data.usage.total == 30
def test_message_trace_with_end_user(trace_instance, monkeypatch):
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -636,7 +636,7 @@ def test_langfuse_trace_entity_with_list_dict_input():
assert data.input[0]["content"] == "hello"
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
# Setup trace info to trigger LLM node usage extraction
trace_info = WorkflowTraceInfo(
workflow_id="wf-1",

View File

@ -35,7 +35,7 @@ def langsmith_config():
@pytest.fixture
def trace_instance(langsmith_config, monkeypatch):
def trace_instance(langsmith_config, monkeypatch: pytest.MonkeyPatch):
# Mock LangSmith client
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
@ -44,7 +44,7 @@ def trace_instance(langsmith_config, monkeypatch):
return instance
def test_init(langsmith_config, monkeypatch):
def test_init(langsmith_config, monkeypatch: pytest.MonkeyPatch):
mock_client_class = MagicMock()
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -57,7 +57,7 @@ def test_init(langsmith_config, monkeypatch):
assert instance.file_base_url == "http://test.url"
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -107,7 +107,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace(trace_instance, monkeypatch):
def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Setup trace info
workflow_data = MagicMock()
workflow_data.created_at = _dt()
@ -223,7 +223,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
assert call_args[4].run_type == LangSmithRunType.retriever
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
def test_workflow_trace_no_start_time(trace_instance, monkeypatch: pytest.MonkeyPatch):
workflow_data = MagicMock()
workflow_data.created_at = _dt()
workflow_data.finished_at = _dt() + timedelta(seconds=1)
@ -266,7 +266,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
assert trace_instance.add_run.called
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.trace_id = "trace-1"
trace_info.message_id = None
@ -290,7 +290,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace(trace_instance, monkeypatch):
def test_message_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "msg-1"
message_data.from_account_id = "acc-1"
@ -516,7 +516,7 @@ def test_update_run_error(trace_instance):
trace_instance.update_run(update_data)
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
workflow_data = MagicMock()
workflow_data.created_at = _dt()
workflow_data.finished_at = _dt() + timedelta(seconds=1)

View File

@ -614,7 +614,7 @@ class TestMessageTrace:
span.set_status.assert_called_once()
span.add_event.assert_called_once()
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch):
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch: pytest.MonkeyPatch):
span = MagicMock()
mock_tracing["start"].return_value = span
mock_tracing["set"].return_value = "token"

View File

@ -35,7 +35,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
def trace_instance(opik_config, monkeypatch: pytest.MonkeyPatch):
mock_client = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
@ -65,7 +65,7 @@ def test_prepare_opik_uuid():
assert result is not None
def test_init(opik_config, monkeypatch):
def test_init(opik_config, monkeypatch: pytest.MonkeyPatch):
mock_opik = MagicMock()
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")
@ -82,7 +82,7 @@ def test_init(opik_config, monkeypatch):
assert instance.project == opik_config.project
def test_trace_dispatch(trace_instance, monkeypatch):
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
methods = [
"workflow_trace",
"message_trace",
@ -132,7 +132,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
mocks["generate_name_trace"].assert_called_once_with(info)
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
@ -221,7 +221,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
assert trace_instance.add_span.call_count >= 1
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
@ -265,7 +265,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
trace_instance.add_trace.assert_called_once()
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
trace_info = WorkflowTraceInfo(
workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
tenant_id="tenant-1",
@ -293,7 +293,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_instance.workflow_trace(trace_info)
def test_message_trace_basic(trace_instance, monkeypatch):
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
# Define constants for better readability
MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
@ -340,7 +340,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
trace_instance.add_span.assert_called_once()
def test_message_trace_with_end_user(trace_instance, monkeypatch):
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
message_data = MagicMock()
message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
message_data.from_account_id = "acc-1"
@ -614,7 +614,7 @@ def test_get_project_url_error(trace_instance):
trace_instance.get_project_url()
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
trace_info = WorkflowTraceInfo(
workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",

View File

@ -267,14 +267,14 @@ class TestInit:
with pytest.raises(ValueError, match="Weave login failed"):
WeaveDataTrace(config)
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch):
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
"""Test FILES_URL is read from environment."""
monkeypatch.setenv("FILES_URL", "http://files.example.com")
config = _make_weave_config()
instance = WeaveDataTrace(config)
assert instance.file_base_url == "http://files.example.com"
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch):
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
"""Test FILES_URL defaults to http://127.0.0.1:5001."""
monkeypatch.delenv("FILES_URL", raising=False)
config = _make_weave_config()
@ -302,7 +302,7 @@ class TestGetProjectUrl:
url = instance.get_project_url()
assert url == "https://wandb.ai/my-project"
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch):
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Raises ValueError when exception occurs in get_project_url."""
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
@ -583,7 +583,7 @@ class TestFinishCall:
class TestWorkflowTrace:
def _setup_repo(self, monkeypatch, nodes=None):
def _setup_repo(self, monkeypatch: pytest.MonkeyPatch, nodes=None):
"""Helper to patch session/repo dependencies."""
if nodes is None:
nodes = []
@ -599,7 +599,7 @@ class TestWorkflowTrace:
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
return repo
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace with no nodes and no message_id."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -614,7 +614,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 1
assert trace_instance.finish_call.call_count == 1
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch):
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace with message_id creates both message and workflow runs."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -629,7 +629,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 2
assert trace_instance.finish_call.call_count == 2
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch):
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Workflow trace iterates node executions and creates node runs."""
node = _make_node(
id="node-1",
@ -652,7 +652,7 @@ class TestWorkflowTrace:
# workflow run + node run = 2 calls
assert trace_instance.start_call.call_count == 2
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch):
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""LLM node uses process_data prompts as inputs."""
node = _make_node(
node_type=BuiltinNodeTypes.LLM,
@ -680,7 +680,7 @@ class TestWorkflowTrace:
# The key "messages" should be present (validator transforms the list)
assert "messages" in node_run.inputs
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch):
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Non-LLM node uses node_execution.inputs directly."""
node = _make_node(
node_type=BuiltinNodeTypes.TOOL,
@ -701,7 +701,7 @@ class TestWorkflowTrace:
node_run = node_call_args[0][0]
assert node_run.inputs.get("tool_input") == "val"
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Raises ValueError when app_id is missing from metadata."""
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
@ -714,7 +714,7 @@ class TestWorkflowTrace:
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch):
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""start_time defaults to datetime.now() when None."""
self._setup_repo(monkeypatch, nodes=[])
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@ -727,7 +727,7 @@ class TestWorkflowTrace:
assert trace_instance.start_call.call_count == 1
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch):
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Node with created_at=None uses datetime.now()."""
node = _make_node(created_at=None, elapsed_time=0.5)
self._setup_repo(monkeypatch, nodes=[node])
@ -740,7 +740,7 @@ class TestWorkflowTrace:
trace_instance.workflow_trace(trace_info)
assert trace_instance.start_call.call_count == 2
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch):
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Chat mode LLM node adds ls_provider and ls_model_name to attributes."""
node = _make_node(
node_type=BuiltinNodeTypes.LLM,
@ -765,7 +765,7 @@ class TestWorkflowTrace:
assert node_run.attributes.get("ls_provider") == "openai"
assert node_run.attributes.get("ls_model_name") == "gpt-4"
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch):
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""Nodes are sorted by created_at before processing."""
node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2))
node2 = _make_node(id="node-a", created_at=_dt())
@ -799,7 +799,7 @@ class TestMessageTrace:
trace_instance.message_trace(trace_info)
trace_instance.start_call.assert_not_called()
def test_basic_message_trace(self, trace_instance, monkeypatch):
def test_basic_message_trace(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace creates message run and llm child run."""
monkeypatch.setattr(
"dify_trace_weave.weave_trace.db.session.get",
@ -816,7 +816,7 @@ class TestMessageTrace:
assert trace_instance.start_call.call_count == 2
assert trace_instance.finish_call.call_count == 2
def test_message_trace_with_file_data(self, trace_instance, monkeypatch):
def test_message_trace_with_file_data(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace appends file URL to file_list."""
file_data = MagicMock()
file_data.url = "path/to/file.png"
@ -839,7 +839,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert "http://files.test/path/to/file.png" in message_run.file_list
def test_message_trace_with_end_user(self, trace_instance, monkeypatch):
def test_message_trace_with_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace looks up end user and sets end_user_id attribute."""
end_user = MagicMock()
end_user.session_id = "session-xyz"
@ -862,7 +862,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert message_run.attributes.get("end_user_id") == "session-xyz"
def test_message_trace_no_end_user(self, trace_instance, monkeypatch):
def test_message_trace_no_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace handles when from_end_user_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
@ -880,7 +880,7 @@ class TestMessageTrace:
trace_instance.message_trace(trace_info)
assert trace_instance.start_call.call_count == 2
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch):
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""trace_id falls back to message_id when trace_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
@ -895,7 +895,7 @@ class TestMessageTrace:
message_run = trace_instance.start_call.call_args_list[0][0][0]
assert message_run.id == "msg-1"
def test_message_trace_file_list_none(self, trace_instance, monkeypatch):
def test_message_trace_file_list_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
"""message_trace handles file_list=None gracefully."""
mock_db = MagicMock()
mock_db.session.get.return_value = None

View File

@ -20,7 +20,7 @@ def test_validate_distance_function_rejects_unsupported_values():
factory._validate_distance_function("dot_product")
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch: pytest.MonkeyPatch):
factory = AlibabaCloudMySQLVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",
@ -45,7 +45,7 @@ def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch
assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch: pytest.MonkeyPatch):
factory = AlibabaCloudMySQLVectorFactory()
dataset = SimpleNamespace(
id="dataset-2",

View File

@ -83,7 +83,7 @@ def test_get_type_is_analyticdb():
assert vector.get_type() == "analyticdb"
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch: pytest.MonkeyPatch):
factory = AnalyticdbVectorFactory()
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
@ -109,7 +109,7 @@ def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
assert dataset.index_struct is not None
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
def test_factory_builds_sql_config_when_host_is_present(monkeypatch: pytest.MonkeyPatch):
factory = AnalyticdbVectorFactory()
dataset = SimpleNamespace(
id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None

View File

@ -24,7 +24,7 @@ def _request_class(name: str):
return _Request
def _install_openapi_stubs(monkeypatch):
def _install_openapi_stubs(monkeypatch: pytest.MonkeyPatch):
gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
gpdb_package.__path__ = []
gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
@ -130,7 +130,7 @@ def test_openapi_config_to_client_params():
assert params["read_timeout"] == 60000
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
initialize_mock = MagicMock()
monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
@ -145,7 +145,7 @@ def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
initialize_mock.assert_called_once_with()
def test_initialize_skips_when_cached(monkeypatch):
def test_initialize_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -164,7 +164,7 @@ def test_initialize_skips_when_cached(monkeypatch):
vector._create_namespace_if_not_exists.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -184,7 +184,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
openapi_module.redis_client.set.assert_called_once()
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
def test_initialize_vector_database_calls_openapi_client(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -199,7 +199,7 @@ def test_initialize_vector_database_calls_openapi_client(monkeypatch):
assert request.manager_account_password == "password"
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
def test_create_namespace_creates_when_namespace_not_found(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -211,7 +211,7 @@ def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
vector._client.create_namespace.assert_called_once()
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -222,7 +222,7 @@ def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
vector._create_namespace_if_not_exists()
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
def test_create_namespace_noop_when_namespace_exists(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector.config = _config()
@ -234,7 +234,7 @@ def test_create_namespace_noop_when_namespace_exists(monkeypatch):
vector._client.create_namespace.assert_not_called()
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
lock = MagicMock()
lock.__enter__.return_value = None
@ -255,7 +255,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
openapi_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -274,7 +274,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
vector._client.create_collection.assert_not_called()
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch: pytest.MonkeyPatch):
stubs = _install_openapi_stubs(monkeypatch)
lock = MagicMock()
lock.__enter__.return_value = None
@ -293,7 +293,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
vector.create_collection_if_not_exists(embedding_dimension=512)
def test_openapi_add_delete_and_search_methods(monkeypatch):
def test_openapi_add_delete_and_search_methods(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -348,7 +348,7 @@ def test_openapi_add_delete_and_search_methods(monkeypatch):
assert docs_by_text[0].page_content == "high"
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
def test_text_exists_returns_false_when_matches_empty(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -361,7 +361,7 @@ def test_text_exists_returns_false_when_matches_empty(monkeypatch):
assert vector.text_exists("missing-id") is False
def test_openapi_delete_success(monkeypatch):
def test_openapi_delete_success(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"
@ -372,7 +372,7 @@ def test_openapi_delete_success(monkeypatch):
vector._client.delete_collection.assert_called_once()
def test_openapi_delete_propagates_errors(monkeypatch):
def test_openapi_delete_propagates_errors(monkeypatch: pytest.MonkeyPatch):
_install_openapi_stubs(monkeypatch)
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
vector._collection_name = "collection_1"

View File

@ -53,7 +53,7 @@ def test_sql_config_rejects_min_connection_greater_than_max_connection():
AnalyticdbVectorBySqlConfig.model_validate(values)
def test_initialize_skips_when_cache_exists(monkeypatch):
def test_initialize_skips_when_cache_exists(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -70,7 +70,7 @@ def test_initialize_skips_when_cache_exists(monkeypatch):
vector._initialize_vector_database.assert_not_called()
def test_initialize_runs_when_cache_is_missing(monkeypatch):
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -88,7 +88,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
sql_module.redis_client.set.assert_called_once()
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -119,7 +119,7 @@ def test_get_cursor_context_manager_handles_connection_lifecycle():
pool.putconn.assert_called_once_with(connection)
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.table_name = "dify.collection"
@ -273,7 +273,7 @@ def test_delete_drops_table():
cursor.execute.assert_called_once()
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch: pytest.MonkeyPatch):
config = AnalyticdbVectorBySqlConfig(**_config_values())
created_pool = MagicMock()
@ -288,7 +288,7 @@ def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypat
assert vector.pool is created_pool
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -326,7 +326,7 @@ def test_initialize_vector_database_handles_existing_database_and_search_config(
assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector.databaseName = "knowledgebase"
@ -353,7 +353,7 @@ def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(mon
worker_connection.rollback.assert_called_once()
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector._collection_name = "collection"
@ -381,7 +381,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
sql_module.redis_client.set.assert_called_once()
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch: pytest.MonkeyPatch):
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
vector._collection_name = "collection"
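For reference, a minimal sketch of the redis-lock mocking pattern that recurs in these vector-store tests; redis_client and the key names here are stand-in MagicMocks and hypothetical values, not the real client or real cache keys:

from unittest.mock import MagicMock

# Stub the distributed lock so `with redis_client.lock(...):` succeeds
# without a running Redis instance.
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None

redis_client = MagicMock()
redis_client.lock.return_value = lock
redis_client.get.return_value = None  # simulate a cache miss so creation runs

with redis_client.lock("vector_indexing_lock_collection"):
    if redis_client.get("vector_indexing_collection") is None:
        redis_client.set("vector_indexing_collection", 1, ex=3600)

redis_client.set.assert_called_once()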

View File

@ -121,7 +121,7 @@ def _build_fake_pymochow_modules():
@pytest.fixture
def baidu_module(monkeypatch):
def baidu_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pymochow_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import dify_vdb_baidu.baidu_vector as module
@ -254,7 +254,7 @@ def test_search_methods_delegate_to_database_table(baidu_module):
assert vector._get_search_res.call_count == 2
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch: pytest.MonkeyPatch):
factory = baidu_module.BaiduVectorFactory()
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
@ -279,7 +279,7 @@ def test_factory_initializes_collection_name_and_index_struct(baidu_module, monk
assert dataset.index_struct is not None
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch: pytest.MonkeyPatch):
init_client = MagicMock(return_value="client")
init_database = MagicMock(return_value="database")
monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
@ -372,7 +372,7 @@ def test_get_search_result_handles_invalid_metadata_json(baidu_module):
assert "document_id" not in docs[0].metadata
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch: pytest.MonkeyPatch):
credentials = MagicMock(return_value="credentials")
configuration = MagicMock(return_value="configuration")
client_cls = MagicMock(return_value="client")
@ -411,7 +411,7 @@ def test_init_database_raises_for_unknown_create_database_error(baidu_module):
vector._init_database()
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._client_config = SimpleNamespace(
@ -460,7 +460,7 @@ def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypat
vector._wait_for_index_ready.assert_called_once_with(table, 3600)
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._db = MagicMock()
@ -493,7 +493,7 @@ def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypat
vector._create_table(3)
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch: pytest.MonkeyPatch):
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
vector._collection_name = "collection_1"
vector._client_config = SimpleNamespace(
@ -524,7 +524,9 @@ def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module,
vector._create_table(3)
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(
baidu_module, monkeypatch: pytest.MonkeyPatch
):
factory = baidu_module.BaiduVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",

View File

@ -44,7 +44,7 @@ def _build_fake_chroma_modules():
@pytest.fixture
def chroma_module(monkeypatch):
def chroma_module(monkeypatch: pytest.MonkeyPatch):
fake_chroma = _build_fake_chroma_modules()
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
import dify_vdb_chroma.chroma_vector as module
@ -73,7 +73,7 @@ def test_chroma_config_to_params_builds_expected_payload(chroma_module):
assert params["settings"].chroma_client_auth_credentials == "credentials"
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -173,7 +173,7 @@ def test_search_by_full_text_returns_empty_list(chroma_module):
assert vector.search_by_full_text("query") == []
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch: pytest.MonkeyPatch):
factory = chroma_module.ChromaVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None

View File

@ -45,7 +45,7 @@ def _build_fake_clickzetta_module():
@pytest.fixture
def clickzetta_module(monkeypatch):
def clickzetta_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
import dify_vdb_clickzetta.clickzetta_vector as module
@ -218,7 +218,7 @@ def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
assert docs[0].metadata["score"] == 0.5
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
factory = clickzetta_module.ClickzettaVectorFactory()
dataset = SimpleNamespace(id="dataset-1")
@ -243,7 +243,7 @@ def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
assert vector_cls.call_args.kwargs["collection_name"] == "collection"
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
clickzetta_module.ClickzettaConnectionPool._instance = None
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
@ -255,7 +255,7 @@ def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch
assert "username:instance:service:workspace:cluster:dify" in key
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -274,7 +274,7 @@ def test_connection_pool_create_connection_retries_and_configures(clickzetta_mod
pool._configure_connection.assert_called_once_with(connection)
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -318,7 +318,7 @@ def test_connection_pool_configure_connection_swallows_errors(clickzetta_module)
monkeypatch.undo()
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
pool = clickzetta_module.ClickzettaConnectionPool()
config = _config(clickzetta_module)
@ -360,7 +360,7 @@ def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monk
assert pool._shutdown is True
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
pool._shutdown = False
pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
@ -384,7 +384,7 @@ def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module
pool._cleanup_expired_connections.assert_called_once()
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
pool.get_connection.return_value = "conn"
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
@ -405,7 +405,7 @@ def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypat
assert vector._ensure_connection() == "conn"
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
class _Thread:
def __init__(self, target, daemon):
self.target = target
@ -579,7 +579,7 @@ def test_create_inverted_index_branches(clickzetta_module):
vector._create_inverted_index(cursor)
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
vector._config = _config(clickzetta_module)
vector._config.batch_size = 2
@ -811,7 +811,7 @@ def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
assert pool._shutdown is True
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
pool._shutdown = False

View File

@ -150,7 +150,7 @@ def _build_fake_couchbase_modules():
@pytest.fixture
def couchbase_module(monkeypatch):
def couchbase_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_couchbase_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -194,7 +194,7 @@ def test_init_sets_cluster_handles(couchbase_module):
vector._cluster.wait_until_ready.assert_called_once()
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
def test_create_and_create_collection_branches(couchbase_module, monkeypatch: pytest.MonkeyPatch):
vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
vector._collection_name = "collection_1"
vector._client_config = _config(couchbase_module)
@ -319,7 +319,7 @@ def test_search_methods_and_format_metadata(couchbase_module):
assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
def test_delete_collection_and_factory(couchbase_module, monkeypatch: pytest.MonkeyPatch):
vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
scopes = [
SimpleNamespace(collections=[SimpleNamespace(name="other")]),

View File

@ -28,7 +28,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def elasticsearch_ja_module(monkeypatch):
def elasticsearch_ja_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -39,7 +39,7 @@ def elasticsearch_ja_module(monkeypatch):
return importlib.reload(ja_module)
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -57,7 +57,7 @@ def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
elasticsearch_ja_module.redis_client.set.assert_not_called()
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -87,7 +87,7 @@ def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monk
elasticsearch_ja_module.redis_client.set.assert_called_once()
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def elasticsearch_module(monkeypatch):
def elasticsearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -287,7 +287,7 @@ def test_search_by_vector_and_full_text(elasticsearch_module):
assert "bool" in query
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -331,7 +331,7 @@ def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
elasticsearch_module.redis_client.set.assert_called_once()
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
factory = elasticsearch_module.ElasticSearchVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_hologres_modules():
@pytest.fixture
def hologres_module(monkeypatch):
def hologres_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_hologres_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -266,7 +266,7 @@ def test_delete_handles_existing_and_missing_tables(hologres_module):
vector._client.drop_table.assert_called_once_with(vector.table_name)
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -281,7 +281,7 @@ def test_create_collection_returns_early_when_cache_hits(hologres_module, monkey
hologres_module.redis_client.set.assert_not_called()
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -313,7 +313,7 @@ def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatc
hologres_module.redis_client.set.assert_called_once()
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = False
@ -331,7 +331,7 @@ def test_create_collection_raises_when_table_never_becomes_ready(hologres_module
hologres_module.redis_client.set.assert_not_called()
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch: pytest.MonkeyPatch):
factory = hologres_module.HologresVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -29,7 +29,7 @@ def _build_fake_elasticsearch_modules():
@pytest.fixture
def huawei_module(monkeypatch):
def huawei_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -155,7 +155,7 @@ def test_search_by_vector_and_full_text(huawei_module):
assert docs[0].page_content == "text-hit"
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch: pytest.MonkeyPatch):
class FakeDocument:
def __init__(self, page_content, vector, metadata):
self.page_content = page_content
@ -185,7 +185,7 @@ def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch
assert docs == []
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
def test_create_and_create_collection_paths(huawei_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -218,7 +218,7 @@ def test_create_and_create_collection_paths(huawei_module, monkeypatch):
huawei_module.redis_client.set.assert_called_once()
def test_huawei_factory_branches(huawei_module, monkeypatch):
def test_huawei_factory_branches(huawei_module, monkeypatch: pytest.MonkeyPatch):
factory = huawei_module.HuaweiCloudVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -23,7 +23,7 @@ def _build_fake_iris_module():
@pytest.fixture
def iris_module(monkeypatch):
def iris_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
import dify_vdb_iris.iris_vector as module
@ -249,7 +249,7 @@ def test_iris_vector_init_get_cursor_and_create(iris_module):
vector._create_collection.assert_called_once_with(2)
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch: pytest.MonkeyPatch):
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", _config(iris_module))
@ -297,7 +297,7 @@ def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
assert docs[0].metadata["score"] == pytest.approx(0.9)
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch: pytest.MonkeyPatch):
cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", cfg)
@ -344,7 +344,7 @@ def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
assert vector_like.search_by_full_text("100%", top_k=1) == []
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch: pytest.MonkeyPatch):
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))

View File

@ -47,7 +47,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def lindorm_module(monkeypatch):
def lindorm_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -100,7 +100,7 @@ def test_to_opensearch_params_and_init(lindorm_module):
assert vector_ugc._routing == "route"
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytest.MonkeyPatch):
vector = lindorm_module.LindormVectorStore(
"collection", _config(lindorm_module), using_ugc=True, routing_value="route"
)
@ -301,7 +301,7 @@ def test_search_by_full_text_success_and_error(lindorm_module):
vector.search_by_full_text("hello")
def test_create_collection_paths(lindorm_module, monkeypatch):
def test_create_collection_paths(lindorm_module, monkeypatch: pytest.MonkeyPatch):
vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
with pytest.raises(ValueError, match="cannot be empty"):
@ -331,7 +331,7 @@ def test_create_collection_paths(lindorm_module, monkeypatch):
vector._client.indices.create.assert_not_called()
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
def test_lindorm_factory_branches(lindorm_module, monkeypatch: pytest.MonkeyPatch):
factory = lindorm_module.LindormVectorStoreFactory()
monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")

View File

@ -32,7 +32,7 @@ def _build_fake_mo_vector_modules():
@pytest.fixture
def matrixone_module(monkeypatch):
def matrixone_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_mo_vector_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -70,7 +70,7 @@ def test_matrixone_config_validation(matrixone_module, field, value, message):
matrixone_module.MatrixoneConfig.model_validate(values)
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -86,7 +86,7 @@ def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module,
matrixone_module.redis_client.set.assert_called_once()
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -146,7 +146,7 @@ def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -165,7 +165,7 @@ def test_get_client_handles_full_text_index_creation_error(matrixone_module, mon
matrixone_module.redis_client.set.assert_not_called()
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch: pytest.MonkeyPatch):
vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
vector.client = MagicMock()
monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
@ -224,7 +224,7 @@ def test_search_by_vector_builds_documents(matrixone_module):
assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch: pytest.MonkeyPatch):
factory = matrixone_module.MatrixoneVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -99,7 +99,7 @@ def _build_fake_pymilvus_modules():
@pytest.fixture
def milvus_module(monkeypatch):
def milvus_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pymilvus_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -327,7 +327,7 @@ def test_process_search_results_and_search_methods(milvus_module):
assert "document_id" in vector._client.search.call_args.kwargs["filter"]
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -351,7 +351,7 @@ def test_create_collection_cache_and_existing_collection(milvus_module, monkeypa
milvus_module.redis_client.set.assert_called()
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -385,7 +385,7 @@ def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch)
assert call_kwargs["consistency_level"] == "Session"
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch: pytest.MonkeyPatch):
factory = milvus_module.MilvusVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -38,7 +38,7 @@ def _build_fake_clickhouse_connect_module():
@pytest.fixture
def myscale_module(monkeypatch):
def myscale_module(monkeypatch: pytest.MonkeyPatch):
fake_module = _build_fake_clickhouse_connect_module()
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
@ -90,7 +90,7 @@ def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
vector._client.command.assert_not_called()
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch: pytest.MonkeyPatch):
factory = myscale_module.MyScaleVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",
@ -160,7 +160,7 @@ def test_create_collection_builds_expected_sql(myscale_module):
assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch: pytest.MonkeyPatch):
vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
docs = [

View File

@ -53,7 +53,7 @@ def _build_fake_pyobvector_module():
@pytest.fixture
def oceanbase_module(monkeypatch):
def oceanbase_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
import dify_vdb_oceanbase.oceanbase_vector as module
@ -208,7 +208,7 @@ def test_create_delegates_to_collection_and_insert(oceanbase_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -234,7 +234,7 @@ def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_mod
vector.delete.assert_not_called()
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -271,7 +271,7 @@ def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, mo
oceanbase_module.redis_client.set.assert_called_once()
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
def test_create_collection_error_paths(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -308,7 +308,7 @@ def test_create_collection_error_paths(oceanbase_module, monkeypatch):
vector._create_collection()
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -517,7 +517,7 @@ def test_delete_success_and_exception(oceanbase_module):
vector.delete()
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
factory = oceanbase_module.OceanBaseVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
@pytest.fixture
def opengauss_module(monkeypatch):
def opengauss_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_psycopg2_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -88,7 +88,7 @@ def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_modu
opengauss_module.OpenGaussConfig.model_validate(values)
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -99,7 +99,7 @@ def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
assert vector.pool is pool
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -126,7 +126,7 @@ def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_called_once()
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -158,7 +158,7 @@ def test_search_by_vector_validates_top_k(opengauss_module):
vector.search_by_vector([0.1, 0.2], top_k=0)
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
@ -200,7 +200,7 @@ def test_create_calls_collection_insert_and_index(opengauss_module):
vector._create_index.assert_called_once_with(2)
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -220,7 +220,7 @@ def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_not_called()
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -245,7 +245,7 @@ def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, m
assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
@ -342,7 +342,7 @@ def test_search_by_full_text_validates_top_k(opengauss_module):
vector.search_by_full_text("query", top_k=0)
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
lock = MagicMock()
@ -370,7 +370,7 @@ def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
opengauss_module.redis_client.set.assert_called_once()
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch: pytest.MonkeyPatch):
factory = opengauss_module.OpenGaussFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -59,7 +59,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def opensearch_module(monkeypatch):
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -95,7 +95,7 @@ class TestOpenSearchConfig:
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] == ("admin", "password")
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch):
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch: pytest.MonkeyPatch):
class _Session:
def get_credentials(self):
return "creds"

View File

@ -58,7 +58,7 @@ def _build_fake_opensearch_modules():
@pytest.fixture
def opensearch_module(monkeypatch):
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -116,7 +116,7 @@ def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
opensearch_module.OpenSearchConfig.model_validate(values)
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch: pytest.MonkeyPatch):
class _Session:
def get_credentials(self):
return "creds"
@ -167,7 +167,7 @@ def test_init_and_create_delegate_calls(opensearch_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch: pytest.MonkeyPatch):
vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
docs = [
Document(page_content="a", metadata={"doc_id": "1"}),
@ -308,7 +308,7 @@ def test_search_by_full_text_and_filters(opensearch_module):
assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -331,7 +331,7 @@ def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch)
opensearch_module.redis_client.set.assert_called()
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch: pytest.MonkeyPatch):
factory = opensearch_module.OpenSearchVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -51,7 +51,7 @@ def _connection_with_cursor(cursor):
@pytest.fixture
def oracle_module(monkeypatch):
def oracle_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_oracle_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -94,7 +94,7 @@ def test_oracle_config_validation_autonomous_requirements(oracle_module):
)
def test_init_and_get_type(oracle_module, monkeypatch):
def test_init_and_get_type(oracle_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
@ -139,7 +139,7 @@ def test_numpy_converters_and_type_handlers(oracle_module):
assert out_float64.dtype == numpy.float64
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
connect = MagicMock(return_value="connection")
monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
@ -173,7 +173,7 @@ def test_create_delegates_collection_and_insert(oracle_module):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
vector.input_type_handler = MagicMock()
@ -279,7 +279,7 @@ def _fake_nltk_module(*, missing_data=False):
return nltk, nltk_corpus
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
@ -305,7 +305,7 @@ def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatc
assert "doc_id_0" in en_params
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch: pytest.MonkeyPatch):
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
vector.table_name = "embedding_collection_1"
vector._get_connection = MagicMock()
@ -320,7 +320,7 @@ def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeyp
vector.search_by_full_text("english query")
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -346,7 +346,9 @@ def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
oracle_module.redis_client.set.assert_called_once()
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(
oracle_module, monkeypatch: pytest.MonkeyPatch
):
factory = oracle_module.OracleVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -79,7 +79,7 @@ def _patch_both(monkeypatch, module, calls, execute_results=None):
@pytest.fixture
def pgvecto_module(monkeypatch):
def pgvecto_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_pgvecto_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -126,7 +126,7 @@ def test_collection_base_has_expected_annotations(pgvecto_module):
assert {"id", "text", "meta", "vector"} <= set(annotations)
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -145,7 +145,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
session_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -169,7 +169,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
module.redis_client.set.assert_called()
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
init_calls = []
runtime_calls = []
@ -241,7 +241,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
init_calls = []
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
@ -313,7 +313,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
assert vector.search_by_full_text("hello") == []
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
module, _ = pgvecto_module
factory = module.PGVectoRSFactory()
dataset_with_index = SimpleNamespace(

View File

@ -336,7 +336,7 @@ def test_create_delegates_collection_creation_and_insert():
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch: pytest.MonkeyPatch):
vector = PGVector.__new__(PGVector)
vector.table_name = "embedding_collection_1"
@ -387,7 +387,7 @@ def test_text_get_and_delete_methods():
assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch: pytest.MonkeyPatch):
vector = PGVector.__new__(PGVector)
vector.table_name = "embedding_collection_1"
cursor = MagicMock()
@ -464,7 +464,7 @@ def test_search_by_full_text_branches_for_bigm_and_standard():
assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch: pytest.MonkeyPatch):
factory = pgvector_module.PGVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -121,7 +121,7 @@ def _build_fake_qdrant_modules():
@pytest.fixture
def qdrant_module(monkeypatch):
def qdrant_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_qdrant_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -170,7 +170,7 @@ def test_init_and_basic_behaviour(qdrant_module):
vector.add_texts.assert_called_once()
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
def test_create_collection_and_add_texts(qdrant_module, monkeypatch: pytest.MonkeyPatch):
vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
lock = MagicMock()
lock.__enter__.return_value = None
@ -288,7 +288,7 @@ def test_search_and_helper_methods(qdrant_module):
assert doc.page_content == "doc"
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
def test_qdrant_factory_paths(qdrant_module, monkeypatch: pytest.MonkeyPatch):
factory = qdrant_module.QdrantVectorFactory()
dataset = SimpleNamespace(
id="dataset-1",

View File

@ -59,7 +59,7 @@ def _patch_both(monkeypatch, module, session):
@pytest.fixture
def relyt_module(monkeypatch):
def relyt_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_relyt_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -97,7 +97,7 @@ def test_relyt_config_validation(relyt_module, field, value, message):
relyt_module.RelytConfig.model_validate(values)
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch: pytest.MonkeyPatch):
engine = MagicMock()
monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
@ -114,7 +114,7 @@ def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -142,7 +142,7 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
relyt_module.redis_client.set.assert_called_once()
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector._group_id = "group-1"
@ -212,7 +212,7 @@ def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
# 3. delete_by_ids translates to uuids
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -225,7 +225,7 @@ def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
# 4. text_exists True
def test_text_exists_true(relyt_module, monkeypatch):
def test_text_exists_true(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -236,7 +236,7 @@ def test_text_exists_true(relyt_module, monkeypatch):
# 5. text_exists False
def test_text_exists_false(relyt_module, monkeypatch):
def test_text_exists_false(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -284,7 +284,7 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
# 8. delete commits session
def test_delete_drops_table(relyt_module, monkeypatch):
def test_delete_drops_table(relyt_module, monkeypatch: pytest.MonkeyPatch):
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
vector._collection_name = "collection_1"
vector.client = MagicMock()
@ -295,7 +295,7 @@ def test_delete_drops_table(relyt_module, monkeypatch):
session.execute.assert_called_once()
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch: pytest.MonkeyPatch):
factory = relyt_module.RelytVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -77,7 +77,7 @@ def _build_fake_tablestore_module():
@pytest.fixture
def tablestore_module(monkeypatch):
def tablestore_module(monkeypatch: pytest.MonkeyPatch):
fake_module = _build_fake_tablestore_module()
monkeypatch.setitem(sys.modules, "tablestore", fake_module)
@ -177,7 +177,7 @@ def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
vector._delete_table_if_exist.assert_called_once()
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch: pytest.MonkeyPatch):
vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
lock = MagicMock()
lock.__enter__.return_value = None
@ -289,7 +289,7 @@ def test_write_row_and_search_helpers(tablestore_module):
assert "score" not in docs[0].metadata
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch: pytest.MonkeyPatch):
factory = tablestore_module.TableStoreVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -136,7 +136,7 @@ def _build_fake_tencent_modules():
@pytest.fixture
def tencent_module(monkeypatch):
def tencent_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_tencent_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -187,7 +187,7 @@ def test_config_and_init_paths(tencent_module):
assert vector._enable_hybrid_search is False
def test_create_collection_branches(tencent_module, monkeypatch):
def test_create_collection_branches(tencent_module, monkeypatch: pytest.MonkeyPatch):
vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
lock = MagicMock()
@ -279,7 +279,7 @@ def test_create_add_delete_and_search_behaviour(tencent_module):
vector._client.drop_collection.assert_called_once()
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch: pytest.MonkeyPatch):
factory = tencent_module.TencentVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -246,8 +246,18 @@ class TidbService:
userPrefix = item["userPrefix"]
if state == "ACTIVE" and len(userPrefix) > 0:
cluster_info = tidb_serverless_list_map[item["clusterId"]]
cluster_info.status = TidbAuthBindingStatus.ACTIVE
cluster_info.account = f"{userPrefix}.root"
if not cluster_info.qdrant_endpoint:
cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint(
item
) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"])
if cluster_info.qdrant_endpoint:
cluster_info.status = TidbAuthBindingStatus.ACTIVE
else:
logger.warning(
"Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later",
item["clusterId"],
)
db.session.add(cluster_info)
db.session.commit()
else:
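The change above gates the ACTIVE transition on a resolvable qdrant endpoint: it first tries to read the endpoint from the batch response, then falls back to a per-cluster fetch, and leaves the binding in its previous status (with a warning) when neither source yields one. Below is a minimal sketch of the extraction step, assuming the payload shape exercised by the new tests (`endpoints.public.host`) and the `https://qdrant-<host>` URL convention asserted there; the module-level function name is illustrative, while the real helper is called as `TidbService.extract_qdrant_endpoint` in the hunk.

```python
# Sketch only: mirrors the behavior the new tests assert, not the shipped helper.
from typing import Any


def extract_qdrant_endpoint(item: dict[str, Any]) -> str | None:
    """Return the qdrant gateway URL for a cluster item, or None when not yet ready."""
    # The batch response nests the public host under endpoints.public.host;
    # a missing or empty host means the endpoint is not provisioned yet.
    host = (item.get("endpoints") or {}).get("public", {}).get("host")
    if not host:
        return None
    # The tests expect host "gw.tidbcloud.com" to map to "https://qdrant-gw.tidbcloud.com".
    return f"https://qdrant-{host}"
```

Keeping the binding out of ACTIVE until this returns a URL (or `fetch_qdrant_endpoint` does) avoids exposing clusters that cannot yet serve qdrant traffic, and the "will retry later" warning makes the deferred case observable.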

View File

@ -1,8 +1,11 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
from models.enums import TidbAuthBindingStatus
class TestExtractQdrantEndpoint:
"""Unit tests for TidbService.extract_qdrant_endpoint."""
@ -216,3 +219,86 @@ class TestBatchCreateEdgeCases:
private_key="priv",
region="us-east-1",
)
class TestBatchUpdateTidbServerlessClusterStatus:
"""Verify that status updates only expose clusters after qdrant endpoint is ready."""
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {
"clusters": [
{
"clusterId": "c-1",
"state": "ACTIVE",
"userPrefix": "pfx",
"endpoints": {"public": {"host": "gw.tidbcloud.com"}},
}
]
},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com")
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com"
assert binding.status == TidbAuthBindingStatus.ACTIVE
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None)
@patch("dify_vdb_tidb_on_qdrant.tidb_service.db")
@patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client")
def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint):
binding = SimpleNamespace(
cluster_id="c-1",
status=TidbAuthBindingStatus.CREATING,
account="root",
qdrant_endpoint=None,
)
mock_http.get.return_value = MagicMock(
status_code=200,
json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]},
)
TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv")
assert binding.account == "pfx.root"
assert binding.qdrant_endpoint is None
assert binding.status == TidbAuthBindingStatus.CREATING
mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1")
mock_db.session.add.assert_called_once_with(binding)
mock_db.session.commit.assert_called_once()

View File

@ -46,7 +46,7 @@ def test_tidb_config_validation(tidb_module, field, value, message):
tidb_module.TiDBVectorConfig.model_validate(values)
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
def test_init_get_type_and_distance_func(tidb_module, monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
@ -63,7 +63,7 @@ def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch: pytest.MonkeyPatch):
fake_tidb_vector = types.ModuleType("tidb_vector")
fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")
@ -107,7 +107,7 @@ def test_create_calls_collection_and_add_texts(tidb_module):
assert vector._dimension == 2
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -127,7 +127,7 @@ def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
tidb_module.redis_client.set.assert_not_called()
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -160,7 +160,7 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
tidb_module.redis_client.set.assert_called_once()
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch):
class _InsertStmt:
def __init__(self, table):
self.table = table
@ -198,7 +198,7 @@ def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
@pytest.fixture
def tidb_vector_with_session(tidb_module, monkeypatch):
def tidb_vector_with_session(tidb_module, monkeypatch: pytest.MonkeyPatch):
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
vector._collection_name = "collection_1"
vector._engine = MagicMock()
@ -354,7 +354,7 @@ def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
# Test search_by_vector filters and scores
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch: pytest.MonkeyPatch):
session = MagicMock()
session.execute.return_value = [
('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
@ -392,7 +392,7 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
# Test delete drops table
def test_delete_drops_table(tidb_module, monkeypatch):
def test_delete_drops_table(tidb_module, monkeypatch: pytest.MonkeyPatch):
session = MagicMock()
session.execute.return_value = None
@ -413,7 +413,7 @@ def test_delete_drops_table(tidb_module, monkeypatch):
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch: pytest.MonkeyPatch):
factory = tidb_module.TiDBVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -36,7 +36,7 @@ def _build_fake_upstash_module():
@pytest.fixture
def upstash_module(monkeypatch):
def upstash_module(monkeypatch: pytest.MonkeyPatch):
# Remove patched modules if present
for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]:
if modname in sys.modules:
@ -65,7 +65,7 @@ def test_upstash_config_validation(upstash_module, field, value, message):
upstash_module.UpstashVectorConfig.model_validate(values)
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
def test_init_get_type_and_dimension(upstash_module, monkeypatch: pytest.MonkeyPatch):
vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
assert vector.get_type() == upstash_module.VectorType.UPSTASH
@ -162,7 +162,7 @@ def test_search_by_vector_filter_threshold_and_delete(upstash_module):
vector.index.reset.assert_called_once()
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch: pytest.MonkeyPatch):
factory = upstash_module.UpstashVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
@pytest.fixture
def vastbase_module(monkeypatch):
def vastbase_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_psycopg2_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -93,7 +93,7 @@ def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
)
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch: pytest.MonkeyPatch):
pool = MagicMock()
monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
@ -114,7 +114,7 @@ def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
pool.putconn.assert_called_once_with(conn)
def test_create_and_add_texts(vastbase_module, monkeypatch):
def test_create_and_add_texts(vastbase_module, monkeypatch: pytest.MonkeyPatch):
vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
vector.table_name = "embedding_collection_1"
vector._create_collection = MagicMock()
@ -205,7 +205,7 @@ def test_search_by_vector_and_full_text(vastbase_module):
assert full_docs[0].page_content == "full-text"
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -240,7 +240,7 @@ def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeyp
vastbase_module.redis_client.set.assert_called()
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch: pytest.MonkeyPatch):
factory = vastbase_module.VastbaseVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",

View File

@ -79,7 +79,7 @@ def _build_fake_vikingdb_modules():
@pytest.fixture
def vikingdb_module(monkeypatch):
def vikingdb_module(monkeypatch: pytest.MonkeyPatch):
for name, module in _build_fake_vikingdb_modules().items():
monkeypatch.setitem(sys.modules, name, module)
@ -117,7 +117,7 @@ def test_init_get_type_and_has_checks(vikingdb_module):
assert vector._has_index() is False
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
lock = MagicMock()
lock.__enter__.return_value = None
lock.__exit__.return_value = None
@ -253,7 +253,7 @@ def test_delete_drops_index_and_collection_when_present(vikingdb_module):
vector._client.drop_collection.assert_not_called()
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
factory = vikingdb_module.VikingDBVectorFactory()
dataset_with_index = SimpleNamespace(
id="dataset-1",
@ -293,7 +293,9 @@ def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, mo
("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
],
)
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
def test_vikingdb_factory_raises_when_required_config_missing(
vikingdb_module, monkeypatch: pytest.MonkeyPatch, field, message
):
factory = vikingdb_module.VikingDBVectorFactory()
dataset = SimpleNamespace(
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None

View File

@ -22,7 +22,6 @@ dependencies = [
"redis[hiredis]>=7.4.0",
"sendgrid>=6.12.5",
"sseclient-py>=1.8.0",
# Stable: production-proven, cap below the next major
"aliyun-log-python-sdk>=0.9.44,<1.0.0",
"azure-identity>=1.25.3,<2.0.0",
@ -42,7 +41,6 @@ dependencies = [
"opentelemetry-propagator-b3>=1.41.1,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.2.2",
@ -98,11 +96,13 @@ dify-trace-mlflow = { workspace = true }
dify-trace-opik = { workspace = true }
dify-trace-tencent = { workspace = true }
dify-trace-weave = { workspace = true }
graphon = { git = "https://github.com/QuantumGhost/graphon", branch = "hitl-form-dev" }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"litellm>=1.83.7",
"pyarrow>=18.0.0",
]
@ -174,7 +174,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.62.0",
"pyrefly>=0.64.0",
"xinference-client>=2.7.0",
]

Some files were not shown because too many files have changed in this diff.