mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 05:56:31 +08:00
Merge branch 'main' into 4-27-app-deploy
This commit is contained in:
commit
1cada0c49c
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@ -99,7 +99,7 @@ jobs:
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@ -116,6 +116,12 @@ jobs:
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Generate API docs
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
|
||||
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@ -37,7 +37,7 @@ jobs:
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
cp envs/middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
@ -87,7 +87,7 @@ jobs:
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
cp envs/middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
|
||||
8
.github/workflows/main-ci.yml
vendored
8
.github/workflows/main-ci.yml
vendored
@ -57,7 +57,7 @@ jobs:
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -84,7 +84,7 @@ jobs:
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
vdb:
|
||||
@ -94,7 +94,7 @@ jobs:
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -116,7 +116,7 @@ jobs:
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
|
||||
22
.github/workflows/pyrefly-diff-comment.yml
vendored
22
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -77,10 +77,28 @@ jobs:
|
||||
}
|
||||
|
||||
if (diff.trim()) {
|
||||
await github.rest.issues.createComment({
|
||||
const body = '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>';
|
||||
const marker = '### Pyrefly Diff';
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
|
||||
});
|
||||
const existing = comments.find((comment) => comment.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
21
.github/workflows/pyrefly-diff.yml
vendored
21
.github/workflows/pyrefly-diff.yml
vendored
@ -103,9 +103,26 @@ jobs:
|
||||
].join('\n')
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
const marker = '### Pyrefly Diff';
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
const existing = comments.find((comment) => comment.body.startsWith(marker));
|
||||
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: existing.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -51,7 +51,7 @@ jobs:
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
4
Makefile
4
Makefile
@ -71,13 +71,13 @@ type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
|
||||
@ -76,11 +76,10 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
./dify-compose up -d
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
#### Seeking help
|
||||
@ -138,7 +137,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, edit `docker/.env`. The essential startup defaults live in [`docker/.env.example`](docker/.env.example), and optional advanced variables are split under `docker/envs/` by theme. After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
|
||||
@ -98,6 +98,8 @@ DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
# Connection pool reset behavior on return
|
||||
SQLALCHEMY_POOL_RESET_ON_RETURN=rollback
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -381,7 +383,7 @@ VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
VIKINGDB_REGION=cn-shanghai
|
||||
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
|
||||
VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_SCHEME=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
@ -432,8 +434,6 @@ UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
|
||||
|
||||
# Mail configuration, support: resend, smtp, sendgrid
|
||||
|
||||
@ -114,7 +114,7 @@ class SQLAlchemyEngineOptionsDict(TypedDict):
|
||||
pool_pre_ping: bool
|
||||
connect_args: dict[str, str]
|
||||
pool_use_lifo: bool
|
||||
pool_reset_on_return: None
|
||||
pool_reset_on_return: Literal["commit", "rollback", None]
|
||||
pool_timeout: int
|
||||
|
||||
|
||||
@ -223,6 +223,11 @@ class DatabaseConfig(BaseSettings):
|
||||
default=30,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_RESET_ON_RETURN: Literal["commit", "rollback", None] = Field(
|
||||
description="Connection pool reset behavior on return. Options: 'commit', 'rollback', or None",
|
||||
default="rollback",
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count() or 1,
|
||||
@ -252,7 +257,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
"pool_reset_on_return": self.SQLALCHEMY_POOL_RESET_ON_RETURN,
|
||||
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
|
||||
}
|
||||
return result
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
"name": "Website Generator"
|
||||
},
|
||||
"app_id": "b53545b1-79ea-4da3-b31a-c39391c6f041",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -35,7 +35,7 @@
|
||||
"name": "Investment Analysis Report Copilot"
|
||||
},
|
||||
"app_id": "a23b57fa-85da-49c0-a571-3aff375976c1",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n",
|
||||
"is_listed": true,
|
||||
@ -51,7 +51,7 @@
|
||||
"name": "Workflow Planning Assistant "
|
||||
},
|
||||
"app_id": "f3303a7d-a81c-404e-b401-1f8711c998c1",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "An assistant that helps you plan and select the right node for a workflow (V0.6.0). ",
|
||||
"is_listed": true,
|
||||
@ -67,7 +67,7 @@
|
||||
"name": "Automated Email Reply "
|
||||
},
|
||||
"app_id": "e9d92058-7d20-4904-892f-75d90bef7587",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Reply emails using Gmail API. It will automatically retrieve email in your inbox and create a response in Gmail. \nConfigure your Gmail API in Google Cloud Console. ",
|
||||
"is_listed": true,
|
||||
@ -83,7 +83,7 @@
|
||||
"name": "Book Translation "
|
||||
},
|
||||
"app_id": "98b87f88-bd22-4d86-8b74-86beba5e0ed4",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A workflow designed to translate a full book up to 15000 tokens per run. Uses Code node to separate text into chunks and Iteration to translate each chunk. ",
|
||||
"is_listed": true,
|
||||
@ -99,7 +99,7 @@
|
||||
"name": "Python bug fixer"
|
||||
},
|
||||
"app_id": "cae337e6-aec5-4c7b-beca-d6f1a808bd5e",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -115,7 +115,7 @@
|
||||
"name": "Code Interpreter"
|
||||
},
|
||||
"app_id": "d077d587-b072-4f2c-b631-69ed1e7cdc0f",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Code interpreter, clarifying the syntax and semantics of the code.",
|
||||
"is_listed": true,
|
||||
@ -131,7 +131,7 @@
|
||||
"name": "SVG Logo Design "
|
||||
},
|
||||
"app_id": "73fbb5f1-c15d-4d74-9cc8-46d9db9b2cca",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL·E 3. ",
|
||||
"is_listed": true,
|
||||
@ -147,7 +147,7 @@
|
||||
"name": "Long Story Generator (Iteration) "
|
||||
},
|
||||
"app_id": "5efb98d7-176b-419c-b6ef-50767391ab62",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A workflow demonstrating how to use Iteration node to generate long article that is longer than the context length of LLMs. ",
|
||||
"is_listed": true,
|
||||
@ -163,7 +163,7 @@
|
||||
"name": "Text Summarization Workflow"
|
||||
},
|
||||
"app_id": "f00c4531-6551-45ee-808f-1d7903099515",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Based on users' choice, retrieve external knowledge to more accurately summarize articles.",
|
||||
"is_listed": true,
|
||||
@ -179,7 +179,7 @@
|
||||
"name": "YouTube Channel Data Analysis"
|
||||
},
|
||||
"app_id": "be591209-2ca8-410f-8f3b-ca0e530dd638",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ",
|
||||
"is_listed": true,
|
||||
@ -195,7 +195,7 @@
|
||||
"name": "Article Grading Bot"
|
||||
},
|
||||
"app_id": "a747f7b4-c48b-40d6-b313-5e628232c05f",
|
||||
"category": "Writing",
|
||||
"categories": ["Writing"],
|
||||
"copyright": null,
|
||||
"description": "Assess the quality of articles and text based on user defined criteria. ",
|
||||
"is_listed": true,
|
||||
@ -211,7 +211,7 @@
|
||||
"name": "SEO Blog Generator"
|
||||
},
|
||||
"app_id": "18f3bd03-524d-4d7a-8374-b30dbe7c69d5",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Workflow for retrieving information from the internet, followed by segmented generation of SEO blogs.",
|
||||
"is_listed": true,
|
||||
@ -227,7 +227,7 @@
|
||||
"name": "SQL Creator"
|
||||
},
|
||||
"app_id": "050ef42e-3e0c-40c1-a6b6-a64f2c49d744",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.",
|
||||
"is_listed": true,
|
||||
@ -243,7 +243,7 @@
|
||||
"name": "Sentiment Analysis "
|
||||
},
|
||||
"app_id": "f06bf86b-d50c-4895-a942-35112dbe4189",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Batch sentiment analysis of text, followed by JSON output of sentiment classification along with scores.",
|
||||
"is_listed": true,
|
||||
@ -259,7 +259,7 @@
|
||||
"name": "Strategic Consulting Expert"
|
||||
},
|
||||
"app_id": "7e8ca1ae-02f2-4b5f-979e-62d19133bee2",
|
||||
"category": "Assistant",
|
||||
"categories": ["Assistant"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "I can answer your questions related to strategic marketing.",
|
||||
"is_listed": true,
|
||||
@ -275,7 +275,7 @@
|
||||
"name": "Code Converter"
|
||||
},
|
||||
"app_id": "4006c4b2-0735-4f37-8dbb-fb1a8c5bd87a",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "This is an application that provides the ability to convert code snippets in multiple programming languages. You can input the code you wish to convert, select the target programming language, and get the desired output.",
|
||||
"is_listed": true,
|
||||
@ -291,7 +291,7 @@
|
||||
"name": "Question Classifier + Knowledge + Chatbot "
|
||||
},
|
||||
"app_id": "d9f6b733-e35d-4a40-9f38-ca7bbfa009f7",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, a chatbot capable of identifying intents alongside with a knowledge base.",
|
||||
"is_listed": true,
|
||||
@ -307,7 +307,7 @@
|
||||
"name": "AI Front-end interviewer"
|
||||
},
|
||||
"app_id": "127efead-8944-4e20-ba9d-12402eb345e0",
|
||||
"category": "HR",
|
||||
"categories": ["HR"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.",
|
||||
"is_listed": true,
|
||||
@ -323,7 +323,7 @@
|
||||
"name": "Knowledge Retrieval + Chatbot "
|
||||
},
|
||||
"app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, A chatbot with a knowledge base. ",
|
||||
"is_listed": true,
|
||||
@ -339,7 +339,7 @@
|
||||
"name": "Email Assistant Workflow "
|
||||
},
|
||||
"app_id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A multifunctional email assistant capable of summarizing, replying, composing, proofreading, and checking grammar.",
|
||||
"is_listed": true,
|
||||
@ -355,7 +355,7 @@
|
||||
"name": "Customer Review Analysis Workflow "
|
||||
},
|
||||
"app_id": "9c0cd31f-4b62-4005-adf5-e3888d08654a",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Utilize LLM (Large Language Models) to classify customer reviews and forward them to the internal system.",
|
||||
"is_listed": true,
|
||||
|
||||
@ -1,6 +1,14 @@
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
|
||||
"""Helpers for registering Pydantic models with Flask-RESTX namespaces.
|
||||
|
||||
Flask-RESTX treats `SchemaModel` bodies as opaque JSON schemas; it does not
|
||||
promote Pydantic's nested `$defs` into top-level Swagger `definitions`.
|
||||
These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@ -8,10 +16,52 @@ from pydantic import BaseModel, TypeAdapter
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a single BaseModel with a namespace for Swagger documentation."""
|
||||
QueryParamDoc = TypedDict(
|
||||
"QueryParamDoc",
|
||||
{
|
||||
"in": NotRequired[str],
|
||||
"type": NotRequired[str],
|
||||
"items": NotRequired[dict[str, object]],
|
||||
"required": NotRequired[bool],
|
||||
"description": NotRequired[str],
|
||||
"enum": NotRequired[list[object]],
|
||||
"default": NotRequired[object],
|
||||
"minimum": NotRequired[int | float],
|
||||
"maximum": NotRequired[int | float],
|
||||
"minLength": NotRequired[int],
|
||||
"maxLength": NotRequired[int],
|
||||
"minItems": NotRequired[int],
|
||||
"maxItems": NotRequired[int],
|
||||
},
|
||||
)
|
||||
|
||||
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
nested_definitions = schema.get("$defs")
|
||||
schema_to_register = dict(schema)
|
||||
if isinstance(nested_definitions, dict):
|
||||
schema_to_register.pop("$defs")
|
||||
|
||||
namespace.schema_model(name, schema_to_register)
|
||||
|
||||
if not isinstance(nested_definitions, dict):
|
||||
return
|
||||
|
||||
for nested_name, nested_schema in nested_definitions.items():
|
||||
if isinstance(nested_schema, dict):
|
||||
_register_json_schema(namespace, nested_name, nested_schema)
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
|
||||
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
|
||||
@ -34,14 +84,111 @@ def get_or_create_model(model_name: str, field_def):
|
||||
def register_enum_models(namespace: Namespace, *models: type[StrEnum]) -> None:
|
||||
"""Register multiple StrEnum with a namespace."""
|
||||
for model in models:
|
||||
namespace.schema_model(
|
||||
model.__name__, TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
_register_json_schema(
|
||||
namespace,
|
||||
model.__name__,
|
||||
TypeAdapter(model).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
"""Build Flask-RESTX query parameter docs from a flat Pydantic model.
|
||||
|
||||
`Namespace.expect()` treats Pydantic schema models as request bodies, so GET
|
||||
endpoints should keep runtime validation on the Pydantic model and feed this
|
||||
derived mapping to `Namespace.doc(params=...)` for Swagger documentation.
|
||||
"""
|
||||
|
||||
schema = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
properties = schema.get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return {}
|
||||
|
||||
required = schema.get("required", [])
|
||||
required_names = set(required) if isinstance(required, list) else set()
|
||||
|
||||
params: dict[str, QueryParamDoc] = {}
|
||||
for name, property_schema in properties.items():
|
||||
if not isinstance(name, str) or not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
params[name] = _query_param_from_property(property_schema, required=name in required_names)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
|
||||
description = param_schema.get("description")
|
||||
if isinstance(description, str):
|
||||
param_doc["description"] = description
|
||||
|
||||
schema_type = param_schema.get("type")
|
||||
if isinstance(schema_type, str) and schema_type in {"array", "boolean", "integer", "number", "string"}:
|
||||
param_doc["type"] = schema_type
|
||||
if schema_type == "array":
|
||||
items = param_schema.get("items")
|
||||
if isinstance(items, Mapping):
|
||||
item_type = items.get("type")
|
||||
if isinstance(item_type, str):
|
||||
param_doc["items"] = {"type": item_type}
|
||||
|
||||
enum = param_schema.get("enum")
|
||||
if isinstance(enum, list):
|
||||
param_doc["enum"] = enum
|
||||
|
||||
default = param_schema.get("default")
|
||||
if default is not None:
|
||||
param_doc["default"] = default
|
||||
|
||||
minimum = param_schema.get("minimum")
|
||||
if isinstance(minimum, int | float):
|
||||
param_doc["minimum"] = minimum
|
||||
|
||||
maximum = param_schema.get("maximum")
|
||||
if isinstance(maximum, int | float):
|
||||
param_doc["maximum"] = maximum
|
||||
|
||||
min_length = param_schema.get("minLength")
|
||||
if isinstance(min_length, int):
|
||||
param_doc["minLength"] = min_length
|
||||
|
||||
max_length = param_schema.get("maxLength")
|
||||
if isinstance(max_length, int):
|
||||
param_doc["maxLength"] = max_length
|
||||
|
||||
min_items = param_schema.get("minItems")
|
||||
if isinstance(min_items, int):
|
||||
param_doc["minItems"] = min_items
|
||||
|
||||
max_items = param_schema.get("maxItems")
|
||||
if isinstance(max_items, int):
|
||||
param_doc["maxItems"] = max_items
|
||||
|
||||
return param_doc
|
||||
|
||||
|
||||
def _nullable_property_schema(property_schema: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
any_of = property_schema.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return property_schema
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
|
||||
if len(non_null_candidates) == 1:
|
||||
return {**property_schema, **non_null_candidates[0]}
|
||||
|
||||
return property_schema
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"register_enum_models",
|
||||
"register_schema_model",
|
||||
"register_schema_models",
|
||||
|
||||
@ -12,6 +12,7 @@ from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
@ -301,15 +302,7 @@ class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -841,7 +842,8 @@ class AppTraceApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id)
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -33,6 +33,7 @@ class AppImportPayload(BaseModel):
|
||||
app_id: str | None = Field(None)
|
||||
|
||||
|
||||
register_enum_models(console_ns, ImportStatus)
|
||||
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
|
||||
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Sequence
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -19,13 +20,12 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
@ -41,16 +41,16 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
reg(ModelConfig)
|
||||
register_enum_models(console_ns, LLMMode)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RuleGeneratePayload,
|
||||
RuleCodeGeneratePayload,
|
||||
RuleStructuredOutputPayload,
|
||||
InstructionGeneratePayload,
|
||||
InstructionTemplatePayload,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
|
||||
@ -5,7 +5,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.base import ResponseModel
|
||||
@ -15,7 +15,7 @@ from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
class RecommendedAppsQuery(BaseModel):
|
||||
language: str | None = Field(default=None)
|
||||
language: str | None = Field(default=None, description="Language code for recommended app localization")
|
||||
|
||||
|
||||
class RecommendedAppInfoResponse(ResponseModel):
|
||||
@ -52,7 +52,7 @@ class RecommendedAppResponse(ResponseModel):
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
category: str | None = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
position: int | None = None
|
||||
is_listed: bool | None = None
|
||||
can_trial: bool | None = None
|
||||
@ -74,7 +74,7 @@ register_schema_models(
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -876,10 +876,10 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
|
||||
tenant_id=current_tenant_id, provider=provider, id=payload.id
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -842,24 +842,24 @@ class WorkflowResponseConverter:
|
||||
return []
|
||||
|
||||
files: list[Mapping[str, Any]] = []
|
||||
if isinstance(value, FileSegment):
|
||||
files.append(value.value.to_dict())
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files.extend([i.to_dict() for i in value.value])
|
||||
elif isinstance(value, File):
|
||||
files.append(value.to_dict())
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
file = cls._get_file_var_from_value(item)
|
||||
match value:
|
||||
case FileSegment():
|
||||
files.append(value.value.to_dict())
|
||||
case ArrayFileSegment():
|
||||
files.extend([i.to_dict() for i in value.value])
|
||||
case File():
|
||||
files.append(value.to_dict())
|
||||
case list():
|
||||
for item in value:
|
||||
file = cls._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
case dict():
|
||||
file = cls._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(
|
||||
value,
|
||||
dict,
|
||||
):
|
||||
file = cls._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
case _:
|
||||
pass
|
||||
|
||||
return files
|
||||
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
from .quota import (
|
||||
deduct_llm_quota,
|
||||
deduct_llm_quota_for_model,
|
||||
ensure_llm_quota_available,
|
||||
ensure_llm_quota_available_for_model,
|
||||
)
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
__all__ = [
|
||||
"deduct_llm_quota",
|
||||
"deduct_llm_quota_for_model",
|
||||
"ensure_llm_quota_available",
|
||||
"ensure_llm_quota_available_for_model",
|
||||
]
|
||||
|
||||
@ -1,4 +1,14 @@
|
||||
from sqlalchemy import update
|
||||
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
|
||||
|
||||
System-hosted quota accounting is currently defined only for LLM models. Keep
|
||||
the public helpers LLM-specific so callers do not carry unused model-type
|
||||
plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
|
||||
with a non-LLM model.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
def _get_provider_configuration(*, tenant_id: str, provider: str):
|
||||
"""Resolve the tenant-bound provider configuration for quota decisions."""
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_configuration = provider_manager.get_configurations(tenant_id).get(provider)
|
||||
if provider_configuration is None:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
return provider_configuration
|
||||
|
||||
|
||||
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
|
||||
"""Raise when a tenant-bound LLM model is already out of quota."""
|
||||
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
model_type=ModelType.LLM,
|
||||
model=model,
|
||||
)
|
||||
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
|
||||
"""Compute the quota impact for an LLM invocation under the current quota mode."""
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
return None
|
||||
|
||||
break
|
||||
|
||||
@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
used_quota = dify_config.get_model_credits(model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
return used_quota
|
||||
|
||||
|
||||
def _deduct_free_llm_quota(
|
||||
*,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
quota_type: ProviderQuotaType,
|
||||
used_quota: int,
|
||||
) -> None:
|
||||
"""Deduct FREE provider quota, capping at the limit before reporting exhaustion."""
|
||||
quota_exceeded = False
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
provider_record = session.scalar(
|
||||
select(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == quota_type,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
if (
|
||||
provider_record is None
|
||||
or provider_record.quota_limit is None
|
||||
or provider_record.quota_used is None
|
||||
or provider_record.quota_limit <= provider_record.quota_used
|
||||
):
|
||||
quota_exceeded = True
|
||||
else:
|
||||
available_quota = provider_record.quota_limit - provider_record.quota_used
|
||||
deducted_quota = min(used_quota, available_quota)
|
||||
provider_record.quota_used += deducted_quota
|
||||
provider_record.last_used = naive_utc_now()
|
||||
quota_exceeded = deducted_quota < used_quota
|
||||
|
||||
if quota_exceeded:
|
||||
raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
|
||||
|
||||
|
||||
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
|
||||
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
match system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
case ProviderQuotaType.FREE:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
_deduct_free_llm_quota(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
quota_type=system_configuration.current_quota_type,
|
||||
used_quota=used_quota,
|
||||
)
|
||||
case _:
|
||||
return
|
||||
|
||||
|
||||
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
|
||||
"""Deduct tenant-bound quota for the resolved LLM model identity."""
|
||||
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
|
||||
used_quota = _resolve_llm_used_quota(
|
||||
system_configuration=provider_configuration.system_configuration,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
_deduct_used_llm_quota(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
provider_configuration=provider_configuration,
|
||||
used_quota=used_quota,
|
||||
)
|
||||
|
||||
|
||||
def _require_llm_model_instance(model_instance: ModelInstance) -> None:
|
||||
"""Reject deprecated wrapper calls that pass a non-LLM model instance."""
|
||||
if model_instance.model_type_instance.model_type != ModelType.LLM:
|
||||
raise ValueError("LLM quota helpers only support LLM model instances.")
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
|
||||
warnings.warn(
|
||||
"ensure_llm_quota_available(model_instance=...) is deprecated; "
|
||||
"use ensure_llm_quota_available_for_model(...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_require_llm_model_instance(model_instance)
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
|
||||
warnings.warn(
|
||||
"deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
|
||||
"use deduct_llm_quota_for_model(...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_require_llm_model_instance(model_instance)
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_instance.provider,
|
||||
model=model_instance.model_name,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
@ -1,36 +1,48 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
This layer centralizes model-quota handling outside node implementations.
|
||||
|
||||
Graphon LLM-backed nodes expose provider/model identity through public node
|
||||
configuration and, after execution, through ``node_run_result.inputs``. Resolve
|
||||
quota billing from that public identity instead of depending on
|
||||
``ModelInstance`` reconstruction inside the workflow layer. Missing identity on
|
||||
quota-tracked nodes is treated as a workflow bug and aborts execution so quota
|
||||
handling is never silently skipped.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final, override
|
||||
from typing import final, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_QUOTA_NODE_TYPES = frozenset(
|
||||
[
|
||||
BuiltinNodeTypes.LLM,
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class LLMQuotaLayer(GraphEngineLayer):
|
||||
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||
"""Graph layer that applies tenant-scoped quota checks to LLM-backed nodes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
tenant_id: str
|
||||
_abort_sent: bool
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
super().__init__()
|
||||
self.tenant_id = tenant_id
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if self._abort_sent:
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
if not self._supports_quota(node):
|
||||
return
|
||||
|
||||
model_identity = self._extract_model_identity_from_node(node)
|
||||
if model_identity is None:
|
||||
reason = "LLM quota check requires public node model identity before execution."
|
||||
self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError")
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
return
|
||||
|
||||
provider, model_name = model_identity
|
||||
try:
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
ensure_llm_quota_available_for_model(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__)
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node):
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
model_identity = self._extract_model_identity_from_result_event(result_event)
|
||||
if model_identity is None:
|
||||
self._abort_for_missing_model_identity(
|
||||
node=node,
|
||||
reason="LLM quota deduction requires model identity in the node result event.",
|
||||
)
|
||||
return
|
||||
|
||||
provider, model_name = model_identity
|
||||
|
||||
try:
|
||||
dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
|
||||
deduct_llm_quota(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
model_instance=model_instance,
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None:
|
||||
self._set_stop_event(node)
|
||||
node.node_data.error_strategy = None
|
||||
node.node_data.retry_config.retry_enabled = False
|
||||
|
||||
def quota_aborted_run() -> NodeRunResult:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=reason,
|
||||
error_type=error_type,
|
||||
)
|
||||
|
||||
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
|
||||
node._run = quota_aborted_run # type: ignore[method-assign]
|
||||
self._send_abort_command(reason=reason)
|
||||
|
||||
def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=reason)
|
||||
logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason)
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
logger.exception("Failed to send quota abort command")
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
model_instance = cast("LLMNode", node).model_instance
|
||||
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
|
||||
model_instance = cast("ParameterExtractorNode", node).model_instance
|
||||
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
|
||||
model_instance = cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
def _supports_quota(node: Node) -> bool:
|
||||
return node.node_type in _QUOTA_NODE_TYPES
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None:
|
||||
provider = result_event.node_run_result.inputs.get("model_provider")
|
||||
model_name = result_event.node_run_result.inputs.get("model_name")
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
return provider, model_name
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None:
|
||||
node_data = getattr(node, "node_data", None)
|
||||
if node_data is None:
|
||||
node_data = getattr(node, "data", None)
|
||||
|
||||
model_config = getattr(node_data, "model", None)
|
||||
if model_config is None:
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
"LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
if isinstance(model_instance, ModelInstance):
|
||||
return model_instance
|
||||
|
||||
raw_model_instance = getattr(model_instance, "_model_instance", None)
|
||||
if isinstance(raw_model_instance, ModelInstance):
|
||||
return raw_model_instance
|
||||
provider = getattr(model_config, "provider", None)
|
||||
model_name = getattr(model_config, "name", None)
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
return provider, model_name
|
||||
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
|
||||
@ -23,7 +23,7 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
ConfigurateMethod,
|
||||
@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.enums import CredentialSourceType
|
||||
@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel):
|
||||
"""Attach the already-composed runtime for request-bound call chains."""
|
||||
self._bound_model_runtime = model_runtime
|
||||
|
||||
def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]:
|
||||
"""Resolve a provider factory that stays aligned with the runtime used by the caller."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime)
|
||||
|
||||
model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id)
|
||||
return model_assembly.model_runtime, model_assembly.model_provider_factory
|
||||
|
||||
def get_model_provider_factory(self) -> ModelProviderFactory:
|
||||
"""Return a provider factory that preserves any request-bound runtime."""
|
||||
if self._bound_model_runtime is not None:
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
_, model_provider_factory = self._get_runtime_and_provider_factory()
|
||||
return model_provider_factory
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel):
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
model_runtime, model_provider_factory = self._get_runtime_and_provider_factory()
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider)
|
||||
return create_model_type_instance(
|
||||
runtime=model_runtime,
|
||||
provider_schema=provider_schema,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
|
||||
text_chunk = secrets.choice(text_chunks)
|
||||
|
||||
try:
|
||||
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
|
||||
|
||||
# Get model instance of LLM
|
||||
model_type_instance = model_provider_factory.get_model_type_instance(
|
||||
model_assembly = create_plugin_model_assembly(tenant_id=tenant_id)
|
||||
model_type_instance = model_assembly.create_model_type_instance(
|
||||
provider=openai_provider_name, model_type=ModelType.MODERATION
|
||||
)
|
||||
model_type_instance = cast(ModerationModel, model_type_instance)
|
||||
|
||||
@ -569,13 +569,13 @@ class OpsTraceManager:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_app_tracing_config(cls, app_id: str):
|
||||
def get_app_tracing_config(cls, app_id: str, session: Session):
|
||||
"""
|
||||
Get app tracing config
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
app: App | None = db.session.get(App, app_id)
|
||||
app: App | None = session.get(App, app_id)
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not app.tracing:
|
||||
|
||||
@ -4,23 +4,32 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Union
|
||||
from typing import IO, Any, Literal, cast, overload
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
|
||||
from configs import dify_config
|
||||
from core.llm_generator.output_parser.structured_output import (
|
||||
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.plugin.impl.asset import PluginAssetManager
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from graphon.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkWithStructuredOutput,
|
||||
LLMResultWithStructuredOutput,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -29,6 +38,68 @@ logger = logging.getLogger(__name__)
|
||||
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
|
||||
|
||||
|
||||
# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
|
||||
# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
|
||||
class _PluginStructuredOutputModelInstance:
|
||||
"""Bind plugin model identity to the shared structured-output helper.
|
||||
|
||||
The structured-output parser is shared with legacy ``ModelInstance`` flows
|
||||
and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
|
||||
intentionally exposes a lower-level API where provider, model, and
|
||||
credentials are passed per call. This adapter supplies the small bound
|
||||
``invoke_llm`` surface the helper needs without constructing a full
|
||||
``ModelInstance`` or reintroducing model-manager dependencies into the
|
||||
plugin runtime path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
runtime: PluginModelRuntime,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
) -> None:
|
||||
self._runtime = runtime
|
||||
self._provider = provider
|
||||
self._model = model
|
||||
self._credentials = credentials
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
callbacks: object | None = None,
|
||||
) -> LLMResult | Generator[LLMResultChunk, None, None]:
|
||||
del callbacks
|
||||
if stream:
|
||||
return self._runtime.invoke_llm(
|
||||
provider=self._provider,
|
||||
model=self._model,
|
||||
credentials=self._credentials,
|
||||
model_parameters=model_parameters or {},
|
||||
prompt_messages=prompt_messages,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
return self._runtime.invoke_llm(
|
||||
provider=self._provider,
|
||||
model=self._model,
|
||||
credentials=self._credentials,
|
||||
model_parameters=model_parameters or {},
|
||||
prompt_messages=prompt_messages,
|
||||
tools=list(tools) if tools else None,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
class PluginModelRuntime(ModelRuntime):
|
||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||
|
||||
@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime):
|
||||
|
||||
return schema
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools: list[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
) -> LLMResult | Generator[LLMResultChunk, None, None]:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
return self.client.invoke_llm(
|
||||
result = self.client.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
plugin_id=plugin_id,
|
||||
@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return result
|
||||
|
||||
return normalize_non_stream_runtime_result(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
result=result,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
json_schema: dict[str, Any],
|
||||
model_parameters: dict[str, Any],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
model_schema = self.get_model_schema(
|
||||
provider=provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
if model_schema is None:
|
||||
raise ValueError(f"Model schema not found for {model}")
|
||||
|
||||
adapter = _PluginStructuredOutputModelInstance(
|
||||
runtime=self,
|
||||
provider=provider,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
)
|
||||
return invoke_llm_with_structured_output_helper(
|
||||
provider=provider,
|
||||
model_schema=model_schema,
|
||||
model_instance=cast(Any, adapter),
|
||||
prompt_messages=prompt_messages,
|
||||
json_schema=json_schema,
|
||||
model_parameters=model_parameters,
|
||||
tools=None,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
|
||||
@ -3,13 +3,46 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_manager import ModelManager
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
|
||||
ModelType.LLM: LargeLanguageModel,
|
||||
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
|
||||
ModelType.RERANK: RerankModel,
|
||||
ModelType.SPEECH2TEXT: Speech2TextModel,
|
||||
ModelType.MODERATION: ModerationModel,
|
||||
ModelType.TTS: TTSModel,
|
||||
}
|
||||
|
||||
|
||||
def create_model_type_instance(
|
||||
*,
|
||||
runtime: ModelRuntime,
|
||||
provider_schema: ProviderEntity,
|
||||
model_type: ModelType,
|
||||
) -> AIModel:
|
||||
"""Build the graphon model wrapper explicitly against the request runtime."""
|
||||
model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
|
||||
if model_class is None:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
return model_class(provider_schema=provider_schema, model_runtime=runtime)
|
||||
|
||||
|
||||
class PluginModelAssembly:
|
||||
"""Compose request-scoped model views on top of a single plugin runtime."""
|
||||
@ -38,9 +71,22 @@ class PluginModelAssembly:
|
||||
@property
|
||||
def model_provider_factory(self) -> ModelProviderFactory:
|
||||
if self._model_provider_factory is None:
|
||||
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
|
||||
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
|
||||
return self._model_provider_factory
|
||||
|
||||
def create_model_type_instance(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
) -> AIModel:
|
||||
provider_schema = self.model_provider_factory.get_provider_schema(provider=provider)
|
||||
return create_model_type_instance(
|
||||
runtime=self.model_runtime,
|
||||
provider_schema=provider_schema,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_manager(self) -> ProviderManager:
|
||||
if self._provider_manager is None:
|
||||
|
||||
@ -53,24 +53,27 @@ class PromptMessageUtil:
|
||||
files = []
|
||||
if isinstance(prompt_message.content, list):
|
||||
for content in prompt_message.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
files.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"detail": content.detail.value,
|
||||
}
|
||||
)
|
||||
elif isinstance(content, AudioPromptMessageContent):
|
||||
files.append(
|
||||
{
|
||||
"type": "audio",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"format": content.format,
|
||||
}
|
||||
)
|
||||
match content:
|
||||
case TextPromptMessageContent():
|
||||
text += content.data
|
||||
case ImagePromptMessageContent():
|
||||
files.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"detail": content.detail.value,
|
||||
}
|
||||
)
|
||||
case AudioPromptMessageContent():
|
||||
files.append(
|
||||
{
|
||||
"type": "audio",
|
||||
"data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
|
||||
"format": content.format,
|
||||
}
|
||||
)
|
||||
case _:
|
||||
continue
|
||||
else:
|
||||
text = cast(str, prompt_message.content)
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from graphon.model_runtime.protocols.runtime import ModelRuntime
|
||||
|
||||
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
|
||||
|
||||
@ -165,7 +165,7 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
# Get all provider entities
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
|
||||
# Get All preferred provider types of the workspace
|
||||
@ -362,7 +362,7 @@ class ProviderManager:
|
||||
if not default_model:
|
||||
return None
|
||||
|
||||
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
|
||||
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
|
||||
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
|
||||
|
||||
return DefaultModelEntity(
|
||||
|
||||
@ -23,36 +23,37 @@ _TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P<tool_file_id>[^/?#
|
||||
|
||||
|
||||
def safe_json_value(v):
|
||||
if isinstance(v, datetime):
|
||||
tz_name = "UTC"
|
||||
if isinstance(current_user, Account) and current_user.timezone is not None:
|
||||
tz_name = current_user.timezone
|
||||
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
||||
elif isinstance(v, date):
|
||||
return v.isoformat()
|
||||
elif isinstance(v, UUID):
|
||||
return str(v)
|
||||
elif isinstance(v, Decimal):
|
||||
return float(v)
|
||||
elif isinstance(v, bytes):
|
||||
try:
|
||||
return v.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return v.hex()
|
||||
elif isinstance(v, memoryview):
|
||||
return v.tobytes().hex()
|
||||
elif isinstance(v, np.integer):
|
||||
return int(v)
|
||||
elif isinstance(v, np.floating):
|
||||
return float(v)
|
||||
elif isinstance(v, np.ndarray):
|
||||
return v.tolist()
|
||||
elif isinstance(v, dict):
|
||||
return safe_json_dict(v)
|
||||
elif isinstance(v, list | tuple | set):
|
||||
return [safe_json_value(i) for i in v]
|
||||
else:
|
||||
return v
|
||||
match v:
|
||||
case datetime():
|
||||
tz_name = "UTC"
|
||||
if isinstance(current_user, Account) and current_user.timezone is not None:
|
||||
tz_name = current_user.timezone
|
||||
return v.astimezone(pytz.timezone(tz_name)).isoformat()
|
||||
case date():
|
||||
return v.isoformat()
|
||||
case UUID():
|
||||
return str(v)
|
||||
case Decimal():
|
||||
return float(v)
|
||||
case bytes():
|
||||
try:
|
||||
return v.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return v.hex()
|
||||
case memoryview():
|
||||
return v.tobytes().hex()
|
||||
case np.integer():
|
||||
return int(v)
|
||||
case np.floating():
|
||||
return float(v)
|
||||
case np.ndarray():
|
||||
return v.tolist()
|
||||
case dict():
|
||||
return safe_json_dict(v)
|
||||
case list() | tuple() | set():
|
||||
return [safe_json_value(i) for i in v]
|
||||
case _:
|
||||
return v
|
||||
|
||||
|
||||
def safe_json_dict(d: dict[str, Any]):
|
||||
|
||||
@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
config_for_node_init: BaseNodeData | dict[str, Any]
|
||||
if isinstance(resolved_node_data, BaseNodeData):
|
||||
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
else:
|
||||
config_for_node_init = resolved_node_data
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
return node_class(
|
||||
node_id=node_id,
|
||||
config=config_for_node_init,
|
||||
data=constructor_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
|
||||
@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
data: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
data: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
data: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
data: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
data=data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Protocol, cast
|
||||
from typing import Any, Protocol
|
||||
from uuid import uuid4
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs:
|
||||
normalized = _normalize_system_variable_values(values, **kwargs)
|
||||
|
||||
return [
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=system_variable_selector(key),
|
||||
name=key,
|
||||
)
|
||||
for key, value in normalized.items()
|
||||
]
|
||||
@ -130,13 +127,10 @@ def build_bootstrap_variables(
|
||||
|
||||
for node_id, value in rag_pipeline_variables_map.items():
|
||||
variables.append(
|
||||
cast(
|
||||
Variable,
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
),
|
||||
segment_to_variable(
|
||||
segment=build_segment(value),
|
||||
selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id),
|
||||
name=node_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, *, tenant_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@staticmethod
|
||||
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
|
||||
"""
|
||||
@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder:
|
||||
config=config,
|
||||
child_engine_builder=self,
|
||||
)
|
||||
child_engine.layer(LLMQuotaLayer())
|
||||
child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id))
|
||||
return child_engine
|
||||
|
||||
|
||||
@ -176,7 +181,7 @@ class WorkflowEntry:
|
||||
self.command_channel = command_channel
|
||||
execution_context = capture_current_context()
|
||||
graph_runtime_state.execution_context = execution_context
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder()
|
||||
self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id)
|
||||
self.graph_engine = GraphEngine(
|
||||
workflow_id=workflow_id,
|
||||
graph=graph,
|
||||
@ -208,7 +213,7 @@ class WorkflowEntry:
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
self.graph_engine.layer(LLMQuotaLayer())
|
||||
self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id))
|
||||
|
||||
# Add observability layer when OTel is enabled
|
||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||
|
||||
95
api/dev/generate_fastopenapi_specs.py
Normal file
95
api/dev/generate_fastopenapi_specs.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Generate FastOpenAPI OpenAPI 3.0 specs without booting the full backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
API_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_ROOT))
|
||||
|
||||
from dev.generate_swagger_specs import apply_runtime_defaults, drop_null_values, sort_openapi_arrays
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FastOpenApiSpecTarget:
|
||||
route: str
|
||||
filename: str
|
||||
|
||||
|
||||
FASTOPENAPI_SPEC_TARGETS: tuple[FastOpenApiSpecTarget, ...] = (
|
||||
FastOpenApiSpecTarget(route="/fastopenapi/openapi.json", filename="fastopenapi-console-openapi.json"),
|
||||
)
|
||||
|
||||
|
||||
def create_fastopenapi_spec_app():
|
||||
"""Build a minimal Flask app that only mounts FastOpenAPI docs routes."""
|
||||
|
||||
apply_runtime_defaults()
|
||||
|
||||
from app_factory import create_flask_app_with_configs
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
app = create_flask_app_with_configs()
|
||||
ext_fastopenapi.init_app(app)
|
||||
return app
|
||||
|
||||
|
||||
def generate_fastopenapi_specs(output_dir: Path) -> list[Path]:
|
||||
"""Write FastOpenAPI specs to `output_dir` and return the written paths."""
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
app = create_fastopenapi_spec_app()
|
||||
client = app.test_client()
|
||||
|
||||
written_paths: list[Path] = []
|
||||
for target in FASTOPENAPI_SPEC_TARGETS:
|
||||
response = client.get(target.route)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}")
|
||||
|
||||
payload = response.get_json()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"unexpected response payload for {target.route}")
|
||||
payload = drop_null_values(payload)
|
||||
payload = sort_openapi_arrays(payload)
|
||||
|
||||
output_path = output_dir / target.filename
|
||||
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
written_paths.append(output_path)
|
||||
|
||||
return written_paths
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("openapi"),
|
||||
help="Directory where the OpenAPI JSON files will be written.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
written_paths = generate_fastopenapi_specs(args.output_dir)
|
||||
|
||||
for path in written_paths:
|
||||
logger.debug(path)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
161
api/dev/generate_swagger_markdown_docs.py
Normal file
161
api/dev/generate_swagger_markdown_docs.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""Generate OpenAPI JSON specs and split Markdown API docs.
|
||||
|
||||
The Markdown step uses `swagger-markdown`, the same converter family as the
|
||||
Swagger Markdown UI, so CI and local regeneration catch converter-incompatible
|
||||
OpenAPI output early.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
API_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(API_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(API_ROOT))
|
||||
|
||||
from dev.generate_fastopenapi_specs import FASTOPENAPI_SPEC_TARGETS, generate_fastopenapi_specs
|
||||
from dev.generate_swagger_specs import SPEC_TARGETS, generate_specs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SWAGGER_MARKDOWN_PACKAGE = "swagger-markdown@3.0.0"
|
||||
CONSOLE_SWAGGER_FILENAME = "console-swagger.json"
|
||||
STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md"
|
||||
|
||||
|
||||
def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None:
|
||||
subprocess.run(
|
||||
[
|
||||
"npx",
|
||||
"--yes",
|
||||
SWAGGER_MARKDOWN_PACKAGE,
|
||||
"-i",
|
||||
str(spec_path),
|
||||
"-o",
|
||||
str(markdown_path),
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str:
|
||||
"""Nest generated Markdown under another Markdown section."""
|
||||
|
||||
heading_prefix = "#" * levels
|
||||
lines = []
|
||||
for line in markdown.splitlines():
|
||||
if line.startswith("#"):
|
||||
lines.append(f"{heading_prefix}{line}")
|
||||
else:
|
||||
lines.append(line)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def _append_fastopenapi_markdown(console_markdown_path: Path, fastopenapi_markdown_path: Path) -> None:
|
||||
"""Append FastOpenAPI console docs to the existing console API Markdown."""
|
||||
|
||||
console_markdown = console_markdown_path.read_text(encoding="utf-8").rstrip()
|
||||
fastopenapi_markdown = _demote_markdown_headings(
|
||||
fastopenapi_markdown_path.read_text(encoding="utf-8"),
|
||||
levels=2,
|
||||
)
|
||||
console_markdown_path.write_text(
|
||||
"\n\n".join(
|
||||
[
|
||||
console_markdown,
|
||||
"## FastOpenAPI Preview (OpenAPI 3.0)",
|
||||
fastopenapi_markdown,
|
||||
]
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def generate_markdown_docs(
|
||||
swagger_dir: Path,
|
||||
markdown_dir: Path,
|
||||
*,
|
||||
keep_swagger_json: bool = False,
|
||||
) -> list[Path]:
|
||||
"""Generate intermediate specs, convert them to split Markdown API docs, and return Markdown paths."""
|
||||
|
||||
swagger_paths = generate_specs(swagger_dir)
|
||||
fastopenapi_paths = generate_fastopenapi_specs(swagger_dir)
|
||||
spec_paths = [*swagger_paths, *fastopenapi_paths]
|
||||
swagger_paths_by_name = {path.name: path for path in swagger_paths}
|
||||
fastopenapi_paths_by_name = {path.name: path for path in fastopenapi_paths}
|
||||
|
||||
markdown_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
written_paths: list[Path] = []
|
||||
try:
|
||||
with tempfile.TemporaryDirectory(prefix="dify-api-docs-") as temp_dir:
|
||||
temp_markdown_dir = Path(temp_dir)
|
||||
|
||||
for target in SPEC_TARGETS:
|
||||
swagger_path = swagger_paths_by_name[target.filename]
|
||||
markdown_path = markdown_dir / f"{swagger_path.stem}.md"
|
||||
_convert_spec_to_markdown(swagger_path, markdown_path)
|
||||
written_paths.append(markdown_path)
|
||||
|
||||
for target in FASTOPENAPI_SPEC_TARGETS: # type: ignore
|
||||
fastopenapi_path = fastopenapi_paths_by_name[target.filename]
|
||||
markdown_path = temp_markdown_dir / f"{fastopenapi_path.stem}.md"
|
||||
_convert_spec_to_markdown(fastopenapi_path, markdown_path)
|
||||
|
||||
console_markdown_path = markdown_dir / f"{Path(CONSOLE_SWAGGER_FILENAME).stem}.md"
|
||||
_append_fastopenapi_markdown(console_markdown_path, markdown_path)
|
||||
|
||||
(markdown_dir / STALE_COMBINED_MARKDOWN_FILENAME).unlink(missing_ok=True)
|
||||
finally:
|
||||
if not keep_swagger_json:
|
||||
for path in spec_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
return written_paths
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--swagger-dir",
|
||||
type=Path,
|
||||
default=Path("openapi"),
|
||||
help="Directory where intermediate JSON spec files will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--markdown-dir",
|
||||
type=Path,
|
||||
default=Path("openapi/markdown"),
|
||||
help="Directory where split Markdown API docs will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-swagger-json",
|
||||
action="store_true",
|
||||
help="Keep intermediate JSON spec files after Markdown generation.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
written_paths = generate_markdown_docs(
|
||||
args.swagger_dir,
|
||||
args.markdown_dir,
|
||||
keep_swagger_json=args.keep_swagger_json,
|
||||
)
|
||||
|
||||
for path in written_paths:
|
||||
logger.debug(path)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -9,12 +9,15 @@ which is unnecessary when the goal is only to serialize the Flask-RESTX
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol, TypeGuard
|
||||
|
||||
from flask import Flask
|
||||
from flask_restx.swagger import Swagger
|
||||
@ -30,19 +33,110 @@ if str(API_ROOT) not in sys.path:
|
||||
class SpecTarget:
|
||||
route: str
|
||||
filename: str
|
||||
namespace: str
|
||||
|
||||
|
||||
class RestxApi(Protocol):
|
||||
models: MutableMapping[str, object]
|
||||
|
||||
def model(self, name: str, model: dict[object, object]) -> object: ...
|
||||
|
||||
|
||||
SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"),
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||
)
|
||||
|
||||
_ORIGINAL_REGISTER_MODEL = Swagger.register_model
|
||||
_ORIGINAL_REGISTER_FIELD = Swagger.register_field
|
||||
|
||||
|
||||
def _apply_runtime_defaults() -> None:
|
||||
def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]:
|
||||
"""Return whether a nested field map is an anonymous inline mapping."""
|
||||
|
||||
from flask_restx.model import Model, OrderedModel
|
||||
|
||||
return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel))
|
||||
|
||||
|
||||
def _jsonable_schema_value(value: object) -> object:
|
||||
"""Return a deterministic JSON-serializable representation for schema fingerprints."""
|
||||
|
||||
if value is None or isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
if isinstance(value, list | tuple):
|
||||
return [_jsonable_schema_value(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _jsonable_schema_value(item) for key, item in value.items()}
|
||||
value_type = type(value)
|
||||
return f"<{value_type.__module__}.{value_type.__qualname__}>"
|
||||
|
||||
|
||||
def _field_signature(field: object) -> object:
|
||||
"""Build a stable signature for a Flask-RESTX field object."""
|
||||
|
||||
from flask_restx import fields
|
||||
from flask_restx.model import instance
|
||||
|
||||
field_instance = instance(field)
|
||||
signature: dict[str, object] = {
|
||||
"class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}"
|
||||
}
|
||||
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested):
|
||||
signature["nested"] = _inline_model_signature(nested)
|
||||
else:
|
||||
signature["nested"] = getattr(
|
||||
nested,
|
||||
"name",
|
||||
f"<{type(nested).__module__}.{type(nested).__qualname__}>",
|
||||
)
|
||||
elif hasattr(field_instance, "container"):
|
||||
signature["container"] = _field_signature(field_instance.container)
|
||||
else:
|
||||
schema = getattr(field_instance, "__schema__", None)
|
||||
if isinstance(schema, dict):
|
||||
signature["schema"] = _jsonable_schema_value(schema)
|
||||
|
||||
for attr_name in (
|
||||
"attribute",
|
||||
"default",
|
||||
"description",
|
||||
"example",
|
||||
"max",
|
||||
"min",
|
||||
"nullable",
|
||||
"readonly",
|
||||
"required",
|
||||
"title",
|
||||
):
|
||||
if hasattr(field_instance, attr_name):
|
||||
signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name))
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
def _inline_model_signature(nested_fields: dict[object, object]) -> object:
|
||||
"""Build a stable signature for an anonymous inline model."""
|
||||
|
||||
return [
|
||||
(str(field_name), _field_signature(field))
|
||||
for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0]))
|
||||
]
|
||||
|
||||
|
||||
def _inline_model_name(nested_fields: dict[object, object]) -> str:
|
||||
"""Return a stable Swagger model name for an anonymous inline field map."""
|
||||
|
||||
signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":"))
|
||||
digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12]
|
||||
return f"_AnonymousInlineModel_{digest}"
|
||||
|
||||
|
||||
def apply_runtime_defaults() -> None:
|
||||
"""Force the small config surface required for Swagger generation."""
|
||||
|
||||
os.environ.setdefault("SECRET_KEY", "spec-export")
|
||||
@ -74,25 +168,26 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
|
||||
anonymous_models = getattr(self, "_anonymous_inline_models", None)
|
||||
if anonymous_models is None:
|
||||
anonymous_models = {}
|
||||
self._anonymous_inline_models = anonymous_models
|
||||
self.__dict__["_anonymous_inline_models"] = anonymous_models
|
||||
|
||||
anonymous_name = anonymous_models.get(id(nested_fields))
|
||||
if anonymous_name is None:
|
||||
anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}"
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
anonymous_models[id(nested_fields)] = anonymous_name
|
||||
self.api.model(anonymous_name, nested_fields)
|
||||
if anonymous_name not in self.api.models:
|
||||
self.api.model(anonymous_name, nested_fields)
|
||||
|
||||
return self.api.models[anonymous_name]
|
||||
|
||||
def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]:
|
||||
if isinstance(model, dict):
|
||||
if _is_inline_field_map(model):
|
||||
model = get_or_create_inline_model(self, model)
|
||||
|
||||
return _ORIGINAL_REGISTER_MODEL(self, model)
|
||||
|
||||
def register_field_with_inline_dict_support(self: Swagger, field: object) -> None:
|
||||
nested = getattr(field, "nested", None)
|
||||
if isinstance(nested, dict):
|
||||
if _is_inline_field_map(nested):
|
||||
field.model = get_or_create_inline_model(self, nested) # type: ignore
|
||||
|
||||
_ORIGINAL_REGISTER_FIELD(self, field)
|
||||
@ -105,22 +200,169 @@ def _patch_swagger_for_inline_nested_dicts() -> None:
|
||||
def create_spec_app() -> Flask:
|
||||
"""Build a minimal Flask app that only mounts the Swagger-producing blueprints."""
|
||||
|
||||
_apply_runtime_defaults()
|
||||
apply_runtime_defaults()
|
||||
_patch_swagger_for_inline_nested_dicts()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.console import console_ns
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.web import web_ns
|
||||
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
for namespace in (console_ns, web_ns, service_api_ns):
|
||||
for api in namespace.apis:
|
||||
_materialize_inline_model_definitions(api)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _registered_models(namespace: str) -> dict[str, object]:
|
||||
"""Return the Flask-RESTX models registered for a Swagger namespace."""
|
||||
|
||||
if namespace == "console":
|
||||
from controllers.console import console_ns
|
||||
|
||||
models = dict(console_ns.models)
|
||||
for api in console_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "web":
|
||||
from controllers.web import web_ns
|
||||
|
||||
models = dict(web_ns.models)
|
||||
for api in web_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "service":
|
||||
from controllers.service_api import service_api_ns
|
||||
|
||||
models = dict(service_api_ns.models)
|
||||
for api in service_api_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
|
||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||
|
||||
|
||||
def _materialize_inline_model_definitions(api: RestxApi) -> None:
|
||||
"""Convert inline `fields.Nested({...})` maps into named API models."""
|
||||
|
||||
from flask_restx import fields
|
||||
from flask_restx.model import Model, OrderedModel, instance
|
||||
|
||||
inline_models: dict[int, dict[object, object]] = {}
|
||||
inline_model_names: dict[int, str] = {}
|
||||
|
||||
def collect_field(field: object) -> None:
|
||||
field_instance = instance(field)
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested) and id(nested) not in inline_models:
|
||||
inline_models[id(nested)] = nested
|
||||
for nested_field in nested.values():
|
||||
collect_field(nested_field)
|
||||
|
||||
container = getattr(field_instance, "container", None)
|
||||
if container is not None:
|
||||
collect_field(container)
|
||||
|
||||
for model in list(api.models.values()):
|
||||
if isinstance(model, (Model, OrderedModel)):
|
||||
for field in model.values():
|
||||
collect_field(field)
|
||||
|
||||
for nested_fields in sorted(inline_models.values(), key=_inline_model_name):
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
inline_model_names[id(nested_fields)] = anonymous_name
|
||||
if anonymous_name not in api.models:
|
||||
api.model(anonymous_name, nested_fields)
|
||||
|
||||
def model_name_for(nested_fields: dict[object, object]) -> str:
|
||||
anonymous_name = inline_model_names.get(id(nested_fields))
|
||||
if anonymous_name is None:
|
||||
anonymous_name = _inline_model_name(nested_fields)
|
||||
inline_model_names[id(nested_fields)] = anonymous_name
|
||||
if anonymous_name not in api.models:
|
||||
api.model(anonymous_name, nested_fields)
|
||||
return anonymous_name
|
||||
|
||||
def materialize_field(field: object) -> None:
|
||||
field_instance = instance(field)
|
||||
if isinstance(field_instance, fields.Nested):
|
||||
nested = getattr(field_instance, "nested", None)
|
||||
if _is_inline_field_map(nested):
|
||||
field_instance.model = api.models[model_name_for(nested)] # type: ignore[attr-defined]
|
||||
|
||||
container = getattr(field_instance, "container", None)
|
||||
if container is not None:
|
||||
materialize_field(container)
|
||||
|
||||
index = 0
|
||||
while index < len(api.models):
|
||||
model = list(api.models.values())[index]
|
||||
index += 1
|
||||
if isinstance(model, (Model, OrderedModel)):
|
||||
for field in model.values():
|
||||
materialize_field(field)
|
||||
|
||||
|
||||
def drop_null_values(value: object) -> object:
|
||||
"""Remove JSON null values that make the Markdown converter crash."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: drop_null_values(item) for key, item in value.items() if item is not None}
|
||||
if isinstance(value, list):
|
||||
return [drop_null_values(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def sort_openapi_arrays(value: object, *, parent_key: str | None = None) -> object:
|
||||
"""Sort order-insensitive Swagger arrays so generated Markdown is stable."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return {key: sort_openapi_arrays(item, parent_key=key) for key, item in value.items()}
|
||||
if not isinstance(value, list):
|
||||
return value
|
||||
|
||||
sorted_items = [sort_openapi_arrays(item, parent_key=parent_key) for item in value]
|
||||
if parent_key == "parameters":
|
||||
return sorted(
|
||||
sorted_items,
|
||||
key=lambda item: (
|
||||
item.get("in", "") if isinstance(item, dict) else "",
|
||||
item.get("name", "") if isinstance(item, dict) else "",
|
||||
json.dumps(item, sort_keys=True, default=str),
|
||||
),
|
||||
)
|
||||
if parent_key in {"enum", "required", "schemes", "tags"}:
|
||||
string_items = [item for item in sorted_items if isinstance(item, str)]
|
||||
if len(string_items) == len(sorted_items):
|
||||
return sorted(string_items)
|
||||
return sorted_items
|
||||
|
||||
|
||||
def _merge_registered_definitions(payload: dict[str, object], namespace: str) -> dict[str, object]:
|
||||
"""Include registered but route-indirect models in the exported Swagger definitions."""
|
||||
|
||||
definitions = payload.setdefault("definitions", {})
|
||||
if not isinstance(definitions, dict):
|
||||
raise RuntimeError("unexpected Swagger definitions payload")
|
||||
|
||||
for name, model in _registered_models(namespace).items():
|
||||
schema = getattr(model, "__schema__", None)
|
||||
if isinstance(schema, dict):
|
||||
definitions.setdefault(name, schema)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def generate_specs(output_dir: Path) -> list[Path]:
|
||||
"""Write all Swagger specs to `output_dir` and return the written paths."""
|
||||
|
||||
@ -138,6 +380,9 @@ def generate_specs(output_dir: Path) -> list[Path]:
|
||||
payload = response.get_json()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"unexpected response payload for {target.route}")
|
||||
payload = _merge_registered_definitions(payload, target.namespace)
|
||||
payload = drop_null_values(payload)
|
||||
payload = sort_openapi_arrays(payload)
|
||||
|
||||
output_path = output_dir / target.filename
|
||||
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
|
||||
@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
|
||||
if used_quota is not None:
|
||||
match provider_configuration.system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
|
||||
raise
|
||||
|
||||
|
||||
def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
|
||||
"""Apply post-generation credit accounting without failing message persistence on quota exhaustion."""
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=credits_required,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
if deducted_credits < credits_required:
|
||||
logger.warning(
|
||||
"Credit pool exhausted during message-created accounting, "
|
||||
"tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
|
||||
tenant_id,
|
||||
pool_type,
|
||||
credits_required,
|
||||
deducted_credits,
|
||||
)
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> int | None:
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
"""add recommended app categories
|
||||
|
||||
Revision ID: a4f2d8c9b731
|
||||
Revises: 227822d22895
|
||||
Create Date: 2026-04-29 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a4f2d8c9b731"
|
||||
down_revision = "227822d22895"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("categories", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
|
||||
batch_op.drop_column("categories")
|
||||
@ -878,6 +878,7 @@ class RecommendedApp(TypeBase):
|
||||
copyright: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
categories: Mapped[list[str] | None] = mapped_column(sa.JSON, nullable=True, default=None)
|
||||
custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
|
||||
position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
|
||||
|
||||
14766
api/openapi/markdown/console-swagger.md
Normal file
14766
api/openapi/markdown/console-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
2754
api/openapi/markdown/service-swagger.md
Normal file
2754
api/openapi/markdown/service-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
1224
api/openapi/markdown/web-swagger.md
Normal file
1224
api/openapi/markdown/web-swagger.md
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from dify_trace_aliyun.entities.semconv import (
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_SESSION_ID,
|
||||
@ -31,7 +32,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from models import EndUser
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch):
|
||||
def test_get_user_id_from_message_data_no_end_user(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = None
|
||||
@ -39,7 +40,7 @@ def test_get_user_id_from_message_data_no_end_user(monkeypatch):
|
||||
assert get_user_id_from_message_data(message_data) == "account_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
def test_get_user_id_from_message_data_with_end_user(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
@ -57,7 +58,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
|
||||
assert get_user_id_from_message_data(message_data) == "session_id"
|
||||
|
||||
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
|
||||
def test_get_user_id_from_message_data_end_user_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.from_account_id = "account_id"
|
||||
message_data.from_end_user_id = "end_user_id"
|
||||
@ -111,7 +112,7 @@ def test_get_workflow_node_status():
|
||||
assert status.status_code == StatusCode.UNSET
|
||||
|
||||
|
||||
def test_create_links_from_trace_id(monkeypatch):
|
||||
def test_create_links_from_trace_id(monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock create_link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
import dify_trace_aliyun.data_exporter.traceclient
|
||||
|
||||
@ -40,7 +40,7 @@ def langfuse_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langfuse_config, monkeypatch):
|
||||
def trace_instance(langfuse_config, monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock Langfuse client to avoid network calls
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
|
||||
@ -49,7 +49,7 @@ def trace_instance(langfuse_config, monkeypatch):
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langfuse_config, monkeypatch):
|
||||
def test_init(langfuse_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_langfuse = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -64,7 +64,7 @@ def test_init(langfuse_config, monkeypatch):
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -114,7 +114,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Setup trace info
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
@ -218,7 +218,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
assert other_span.level == LevelEnum.ERROR
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -259,7 +259,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
assert trace_data.name == TraceTaskName.WORKFLOW_TRACE
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
tenant_id="tenant-1",
|
||||
@ -287,7 +287,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -331,7 +331,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
assert gen_data.usage.total == 30
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -636,7 +636,7 @@ def test_langfuse_trace_entity_with_list_dict_input():
|
||||
assert data.input[0]["content"] == "hello"
|
||||
|
||||
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
# Setup trace info to trigger LLM node usage extraction
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="wf-1",
|
||||
|
||||
@ -35,7 +35,7 @@ def langsmith_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(langsmith_config, monkeypatch):
|
||||
def trace_instance(langsmith_config, monkeypatch: pytest.MonkeyPatch):
|
||||
# Mock LangSmith client
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
|
||||
@ -44,7 +44,7 @@ def trace_instance(langsmith_config, monkeypatch):
|
||||
return instance
|
||||
|
||||
|
||||
def test_init(langsmith_config, monkeypatch):
|
||||
def test_init(langsmith_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_client_class = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -57,7 +57,7 @@ def test_init(langsmith_config, monkeypatch):
|
||||
assert instance.file_base_url == "http://test.url"
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -107,7 +107,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace(trace_instance, monkeypatch):
|
||||
def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Setup trace info
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
@ -223,7 +223,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
|
||||
assert call_args[4].run_type == LangSmithRunType.retriever
|
||||
|
||||
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_start_time(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
@ -266,7 +266,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
|
||||
assert trace_instance.add_run.called
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.trace_id = "trace-1"
|
||||
trace_info.message_id = None
|
||||
@ -290,7 +290,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace(trace_instance, monkeypatch):
|
||||
def test_message_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "msg-1"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -516,7 +516,7 @@ def test_update_run_error(trace_instance):
|
||||
trace_instance.update_run(update_data)
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
workflow_data = MagicMock()
|
||||
workflow_data.created_at = _dt()
|
||||
workflow_data.finished_at = _dt() + timedelta(seconds=1)
|
||||
|
||||
@ -614,7 +614,7 @@ class TestMessageTrace:
|
||||
span.set_status.assert_called_once()
|
||||
span.add_event.assert_called_once()
|
||||
|
||||
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch):
|
||||
def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch: pytest.MonkeyPatch):
|
||||
span = MagicMock()
|
||||
mock_tracing["start"].return_value = span
|
||||
mock_tracing["set"].return_value = "token"
|
||||
|
||||
@ -35,7 +35,7 @@ def opik_config():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_instance(opik_config, monkeypatch):
|
||||
def trace_instance(opik_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_client = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
|
||||
|
||||
@ -65,7 +65,7 @@ def test_prepare_opik_uuid():
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_init(opik_config, monkeypatch):
|
||||
def test_init(opik_config, monkeypatch: pytest.MonkeyPatch):
|
||||
mock_opik = MagicMock()
|
||||
monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
|
||||
monkeypatch.setenv("FILES_URL", "http://test.url")
|
||||
@ -82,7 +82,7 @@ def test_init(opik_config, monkeypatch):
|
||||
assert instance.project == opik_config.project
|
||||
|
||||
|
||||
def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
def test_trace_dispatch(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
methods = [
|
||||
"workflow_trace",
|
||||
"message_trace",
|
||||
@ -132,7 +132,7 @@ def test_trace_dispatch(trace_instance, monkeypatch):
|
||||
mocks["generate_name_trace"].assert_called_once_with(info)
|
||||
|
||||
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce"
|
||||
WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef"
|
||||
@ -221,7 +221,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
|
||||
assert trace_instance.add_span.call_count >= 1
|
||||
|
||||
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_message_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3"
|
||||
WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac"
|
||||
@ -265,7 +265,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
|
||||
trace_instance.add_trace.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a",
|
||||
tenant_id="tenant-1",
|
||||
@ -293,7 +293,7 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
|
||||
def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
def test_message_trace_basic(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
# Define constants for better readability
|
||||
MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab"
|
||||
CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3"
|
||||
@ -340,7 +340,7 @@ def test_message_trace_basic(trace_instance, monkeypatch):
|
||||
trace_instance.add_span.assert_called_once()
|
||||
|
||||
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
message_data = MagicMock()
|
||||
message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e"
|
||||
message_data.from_account_id = "acc-1"
|
||||
@ -614,7 +614,7 @@ def test_get_project_url_error(trace_instance):
|
||||
trace_instance.get_project_url()
|
||||
|
||||
|
||||
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog):
|
||||
def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch: pytest.MonkeyPatch, caplog):
|
||||
trace_info = WorkflowTraceInfo(
|
||||
workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de",
|
||||
tenant_id="66e8e918-472e-4b69-8051-12502c34fc07",
|
||||
|
||||
@ -267,14 +267,14 @@ class TestInit:
|
||||
with pytest.raises(ValueError, match="Weave login failed"):
|
||||
WeaveDataTrace(config)
|
||||
|
||||
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch):
|
||||
def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FILES_URL is read from environment."""
|
||||
monkeypatch.setenv("FILES_URL", "http://files.example.com")
|
||||
config = _make_weave_config()
|
||||
instance = WeaveDataTrace(config)
|
||||
assert instance.file_base_url == "http://files.example.com"
|
||||
|
||||
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch):
|
||||
def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FILES_URL defaults to http://127.0.0.1:5001."""
|
||||
monkeypatch.delenv("FILES_URL", raising=False)
|
||||
config = _make_weave_config()
|
||||
@ -302,7 +302,7 @@ class TestGetProjectUrl:
|
||||
url = instance.get_project_url()
|
||||
assert url == "https://wandb.ai/my-project"
|
||||
|
||||
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch):
|
||||
def test_get_project_url_exception_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Raises ValueError when exception occurs in get_project_url."""
|
||||
monkeypatch.setattr(trace_instance, "entity", None)
|
||||
monkeypatch.setattr(trace_instance, "project_name", None)
|
||||
@ -583,7 +583,7 @@ class TestFinishCall:
|
||||
|
||||
|
||||
class TestWorkflowTrace:
|
||||
def _setup_repo(self, monkeypatch, nodes=None):
|
||||
def _setup_repo(self, monkeypatch: pytest.MonkeyPatch, nodes=None):
|
||||
"""Helper to patch session/repo dependencies."""
|
||||
if nodes is None:
|
||||
nodes = []
|
||||
@ -599,7 +599,7 @@ class TestWorkflowTrace:
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
|
||||
return repo
|
||||
|
||||
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace with no nodes and no message_id."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -614,7 +614,7 @@ class TestWorkflowTrace:
|
||||
assert trace_instance.start_call.call_count == 1
|
||||
assert trace_instance.finish_call.call_count == 1
|
||||
|
||||
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace with message_id creates both message and workflow runs."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -629,7 +629,7 @@ class TestWorkflowTrace:
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
assert trace_instance.finish_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Workflow trace iterates node executions and creates node runs."""
|
||||
node = _make_node(
|
||||
id="node-1",
|
||||
@ -652,7 +652,7 @@ class TestWorkflowTrace:
|
||||
# workflow run + node run = 2 calls
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""LLM node uses process_data prompts as inputs."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
@ -680,7 +680,7 @@ class TestWorkflowTrace:
|
||||
# The key "messages" should be present (validator transforms the list)
|
||||
assert "messages" in node_run.inputs
|
||||
|
||||
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Non-LLM node uses node_execution.inputs directly."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.TOOL,
|
||||
@ -701,7 +701,7 @@ class TestWorkflowTrace:
|
||||
node_run = node_call_args[0][0]
|
||||
assert node_run.inputs.get("tool_input") == "val"
|
||||
|
||||
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Raises ValueError when app_id is missing from metadata."""
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
|
||||
monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
|
||||
@ -714,7 +714,7 @@ class TestWorkflowTrace:
|
||||
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
|
||||
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""start_time defaults to datetime.now() when None."""
|
||||
self._setup_repo(monkeypatch, nodes=[])
|
||||
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
|
||||
@ -727,7 +727,7 @@ class TestWorkflowTrace:
|
||||
|
||||
assert trace_instance.start_call.call_count == 1
|
||||
|
||||
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Node with created_at=None uses datetime.now()."""
|
||||
node = _make_node(created_at=None, elapsed_time=0.5)
|
||||
self._setup_repo(monkeypatch, nodes=[node])
|
||||
@ -740,7 +740,7 @@ class TestWorkflowTrace:
|
||||
trace_instance.workflow_trace(trace_info)
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Chat mode LLM node adds ls_provider and ls_model_name to attributes."""
|
||||
node = _make_node(
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
@ -765,7 +765,7 @@ class TestWorkflowTrace:
|
||||
assert node_run.attributes.get("ls_provider") == "openai"
|
||||
assert node_run.attributes.get("ls_model_name") == "gpt-4"
|
||||
|
||||
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch):
|
||||
def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Nodes are sorted by created_at before processing."""
|
||||
node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2))
|
||||
node2 = _make_node(id="node-a", created_at=_dt())
|
||||
@ -799,7 +799,7 @@ class TestMessageTrace:
|
||||
trace_instance.message_trace(trace_info)
|
||||
trace_instance.start_call.assert_not_called()
|
||||
|
||||
def test_basic_message_trace(self, trace_instance, monkeypatch):
|
||||
def test_basic_message_trace(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace creates message run and llm child run."""
|
||||
monkeypatch.setattr(
|
||||
"dify_trace_weave.weave_trace.db.session.get",
|
||||
@ -816,7 +816,7 @@ class TestMessageTrace:
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
assert trace_instance.finish_call.call_count == 2
|
||||
|
||||
def test_message_trace_with_file_data(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_with_file_data(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace appends file URL to file_list."""
|
||||
file_data = MagicMock()
|
||||
file_data.url = "path/to/file.png"
|
||||
@ -839,7 +839,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert "http://files.test/path/to/file.png" in message_run.file_list
|
||||
|
||||
def test_message_trace_with_end_user(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_with_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace looks up end user and sets end_user_id attribute."""
|
||||
end_user = MagicMock()
|
||||
end_user.session_id = "session-xyz"
|
||||
@ -862,7 +862,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert message_run.attributes.get("end_user_id") == "session-xyz"
|
||||
|
||||
def test_message_trace_no_end_user(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_no_end_user(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace handles when from_end_user_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
@ -880,7 +880,7 @@ class TestMessageTrace:
|
||||
trace_instance.message_trace(trace_info)
|
||||
assert trace_instance.start_call.call_count == 2
|
||||
|
||||
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""trace_id falls back to message_id when trace_id is None."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
@ -895,7 +895,7 @@ class TestMessageTrace:
|
||||
message_run = trace_instance.start_call.call_args_list[0][0][0]
|
||||
assert message_run.id == "msg-1"
|
||||
|
||||
def test_message_trace_file_list_none(self, trace_instance, monkeypatch):
|
||||
def test_message_trace_file_list_none(self, trace_instance, monkeypatch: pytest.MonkeyPatch):
|
||||
"""message_trace handles file_list=None gracefully."""
|
||||
mock_db = MagicMock()
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
@ -20,7 +20,7 @@ def test_validate_distance_function_rejects_unsupported_values():
|
||||
factory._validate_distance_function("dot_product")
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch):
|
||||
def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -45,7 +45,7 @@ def test_factory_init_vector_uses_existing_index_struct_class_prefix(monkeypatch
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "existing_collection"
|
||||
|
||||
|
||||
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch):
|
||||
def test_factory_init_vector_generates_collection_name_when_index_struct_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AlibabaCloudMySQLVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2",
|
||||
|
||||
@ -83,7 +83,7 @@ def test_get_type_is_analyticdb():
|
||||
assert vector.get_type() == "analyticdb"
|
||||
|
||||
|
||||
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
|
||||
def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
|
||||
|
||||
@ -109,7 +109,7 @@ def test_factory_builds_openapi_config_when_host_is_missing(monkeypatch):
|
||||
assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_factory_builds_sql_config_when_host_is_present(monkeypatch):
|
||||
def test_factory_builds_sql_config_when_host_is_present(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = AnalyticdbVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-2", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
|
||||
|
||||
@ -24,7 +24,7 @@ def _request_class(name: str):
|
||||
return _Request
|
||||
|
||||
|
||||
def _install_openapi_stubs(monkeypatch):
|
||||
def _install_openapi_stubs(monkeypatch: pytest.MonkeyPatch):
|
||||
gpdb_package = types.ModuleType("alibabacloud_gpdb20160503")
|
||||
gpdb_package.__path__ = []
|
||||
gpdb_models = types.ModuleType("alibabacloud_gpdb20160503.models")
|
||||
@ -130,7 +130,7 @@ def test_openapi_config_to_client_params():
|
||||
assert params["read_timeout"] == 60000
|
||||
|
||||
|
||||
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
|
||||
def test_init_creates_openapi_client_and_runs_initialize(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
initialize_mock = MagicMock()
|
||||
monkeypatch.setattr(openapi_module.AnalyticdbVectorOpenAPI, "_initialize", initialize_mock)
|
||||
@ -145,7 +145,7 @@ def test_init_creates_openapi_client_and_runs_initialize(monkeypatch):
|
||||
initialize_mock.assert_called_once_with()
|
||||
|
||||
|
||||
def test_initialize_skips_when_cached(monkeypatch):
|
||||
def test_initialize_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -164,7 +164,7 @@ def test_initialize_skips_when_cached(monkeypatch):
|
||||
vector._create_namespace_if_not_exists.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -184,7 +184,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_initialize_vector_database_calls_openapi_client(monkeypatch):
|
||||
def test_initialize_vector_database_calls_openapi_client(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -199,7 +199,7 @@ def test_initialize_vector_database_calls_openapi_client(monkeypatch):
|
||||
assert request.manager_account_password == "password"
|
||||
|
||||
|
||||
def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
|
||||
def test_create_namespace_creates_when_namespace_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -211,7 +211,7 @@ def test_create_namespace_creates_when_namespace_not_found(monkeypatch):
|
||||
vector._client.create_namespace.assert_called_once()
|
||||
|
||||
|
||||
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
|
||||
def test_create_namespace_raises_on_unexpected_api_error(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -222,7 +222,7 @@ def test_create_namespace_raises_on_unexpected_api_error(monkeypatch):
|
||||
vector._create_namespace_if_not_exists()
|
||||
|
||||
|
||||
def test_create_namespace_noop_when_namespace_exists(monkeypatch):
|
||||
def test_create_namespace_noop_when_namespace_exists(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector.config = _config()
|
||||
@ -234,7 +234,7 @@ def test_create_namespace_noop_when_namespace_exists(monkeypatch):
|
||||
vector._client.create_namespace.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
def test_create_collection_if_not_exists_creates_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -255,7 +255,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch):
|
||||
openapi_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
def test_create_collection_if_not_exists_skips_when_cached(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -274,7 +274,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch):
|
||||
vector._client.create_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch: pytest.MonkeyPatch):
|
||||
stubs = _install_openapi_stubs(monkeypatch)
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -293,7 +293,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch):
|
||||
vector.create_collection_if_not_exists(embedding_dimension=512)
|
||||
|
||||
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
def test_openapi_add_delete_and_search_methods(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -348,7 +348,7 @@ def test_openapi_add_delete_and_search_methods(monkeypatch):
|
||||
assert docs_by_text[0].page_content == "high"
|
||||
|
||||
|
||||
def test_text_exists_returns_false_when_matches_empty(monkeypatch):
|
||||
def test_text_exists_returns_false_when_matches_empty(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -361,7 +361,7 @@ def test_text_exists_returns_false_when_matches_empty(monkeypatch):
|
||||
assert vector.text_exists("missing-id") is False
|
||||
|
||||
|
||||
def test_openapi_delete_success(monkeypatch):
|
||||
def test_openapi_delete_success(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
@ -372,7 +372,7 @@ def test_openapi_delete_success(monkeypatch):
|
||||
vector._client.delete_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_openapi_delete_propagates_errors(monkeypatch):
|
||||
def test_openapi_delete_propagates_errors(monkeypatch: pytest.MonkeyPatch):
|
||||
_install_openapi_stubs(monkeypatch)
|
||||
vector = AnalyticdbVectorOpenAPI.__new__(AnalyticdbVectorOpenAPI)
|
||||
vector._collection_name = "collection_1"
|
||||
|
||||
@ -53,7 +53,7 @@ def test_sql_config_rejects_min_connection_greater_than_max_connection():
|
||||
AnalyticdbVectorBySqlConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_initialize_skips_when_cache_exists(monkeypatch):
|
||||
def test_initialize_skips_when_cache_exists(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -70,7 +70,7 @@ def test_initialize_skips_when_cache_exists(monkeypatch):
|
||||
vector._initialize_vector_database.assert_not_called()
|
||||
|
||||
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
def test_initialize_runs_when_cache_is_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -88,7 +88,7 @@ def test_initialize_runs_when_cache_is_missing(monkeypatch):
|
||||
sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch):
|
||||
def test_create_connection_pool_uses_psycopg2_pool(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -119,7 +119,7 @@ def test_get_cursor_context_manager_handles_connection_lifecycle():
|
||||
pool.putconn.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch):
|
||||
def test_add_texts_inserts_only_documents_with_metadata(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.table_name = "dify.collection"
|
||||
|
||||
@ -273,7 +273,7 @@ def test_delete_drops_table():
|
||||
cursor.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch):
|
||||
def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypatch: pytest.MonkeyPatch):
|
||||
config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
created_pool = MagicMock()
|
||||
|
||||
@ -288,7 +288,7 @@ def test_init_normalizes_collection_name_and_creates_pool_when_missing(monkeypat
|
||||
assert vector.pool is created_pool
|
||||
|
||||
|
||||
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch):
|
||||
def test_initialize_vector_database_handles_existing_database_and_search_config(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -326,7 +326,7 @@ def test_initialize_vector_database_handles_existing_database_and_search_config(
|
||||
assert any("CREATE SCHEMA IF NOT EXISTS dify" in call.args[0] for call in worker_cursor.execute.call_args_list)
|
||||
|
||||
|
||||
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch):
|
||||
def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector.databaseName = "knowledgebase"
|
||||
@ -353,7 +353,7 @@ def test_initialize_vector_database_raises_runtime_error_when_zhparser_fails(mon
|
||||
worker_connection.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch):
|
||||
def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector._collection_name = "collection"
|
||||
@ -381,7 +381,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp
|
||||
sql_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch):
|
||||
def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = AnalyticdbVectorBySql.__new__(AnalyticdbVectorBySql)
|
||||
vector.config = AnalyticdbVectorBySqlConfig(**_config_values())
|
||||
vector._collection_name = "collection"
|
||||
|
||||
@ -121,7 +121,7 @@ def _build_fake_pymochow_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def baidu_module(monkeypatch):
|
||||
def baidu_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pymochow_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
import dify_vdb_baidu.baidu_vector as module
|
||||
@ -254,7 +254,7 @@ def test_search_methods_delegate_to_database_table(baidu_module):
|
||||
assert vector._get_search_res.call_count == 2
|
||||
|
||||
|
||||
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch):
|
||||
def test_factory_initializes_collection_name_and_index_struct(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = baidu_module.BaiduVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1", index_struct_dict=None, index_struct=None)
|
||||
monkeypatch.setattr(baidu_module.Dataset, "gen_collection_name_by_id", lambda _id: "AUTO_COLLECTION")
|
||||
@ -279,7 +279,7 @@ def test_factory_initializes_collection_name_and_index_struct(baidu_module, monk
|
||||
assert dataset.index_struct is not None
|
||||
|
||||
|
||||
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch):
|
||||
def test_init_get_type_to_index_struct_and_create_delegate(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
init_client = MagicMock(return_value="client")
|
||||
init_database = MagicMock(return_value="database")
|
||||
monkeypatch.setattr(baidu_module.BaiduVector, "_init_client", init_client)
|
||||
@ -372,7 +372,7 @@ def test_get_search_result_handles_invalid_metadata_json(baidu_module):
|
||||
assert "document_id" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch):
|
||||
def test_init_client_constructs_configuration_and_client(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
credentials = MagicMock(return_value="credentials")
|
||||
configuration = MagicMock(return_value="configuration")
|
||||
client_cls = MagicMock(return_value="client")
|
||||
@ -411,7 +411,7 @@ def test_init_database_raises_for_unknown_create_database_error(baidu_module):
|
||||
vector._init_database()
|
||||
|
||||
|
||||
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch):
|
||||
def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = SimpleNamespace(
|
||||
@ -460,7 +460,7 @@ def test_create_table_handles_cache_and_validation_paths(baidu_module, monkeypat
|
||||
vector._wait_for_index_ready.assert_called_once_with(table, 3600)
|
||||
|
||||
|
||||
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch):
|
||||
def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._db = MagicMock()
|
||||
@ -493,7 +493,7 @@ def test_create_table_raises_for_invalid_index_or_metric(baidu_module, monkeypat
|
||||
vector._create_table(3)
|
||||
|
||||
|
||||
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch):
|
||||
def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = baidu_module.BaiduVector.__new__(baidu_module.BaiduVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = SimpleNamespace(
|
||||
@ -524,7 +524,9 @@ def test_create_table_raises_timeout_if_table_never_becomes_normal(baidu_module,
|
||||
vector._create_table(3)
|
||||
|
||||
|
||||
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(baidu_module, monkeypatch):
|
||||
def test_factory_uses_existing_collection_prefix_when_index_struct_exists(
|
||||
baidu_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
factory = baidu_module.BaiduVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -44,7 +44,7 @@ def _build_fake_chroma_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_module(monkeypatch):
|
||||
def chroma_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_chroma = _build_fake_chroma_modules()
|
||||
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
|
||||
import dify_vdb_chroma.chroma_vector as module
|
||||
@ -73,7 +73,7 @@ def test_chroma_config_to_params_builds_expected_payload(chroma_module):
|
||||
assert params["settings"].chroma_client_auth_credentials == "credentials"
|
||||
|
||||
|
||||
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch):
|
||||
def test_create_collection_uses_redis_lock_and_cache(chroma_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -173,7 +173,7 @@ def test_search_by_full_text_returns_empty_list(chroma_module):
|
||||
assert vector.search_by_full_text("query") == []
|
||||
|
||||
|
||||
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch):
|
||||
def test_factory_init_vector_uses_existing_or_generated_collection(chroma_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = chroma_module.ChromaVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "EXISTING"}}, index_struct=None
|
||||
|
||||
@ -45,7 +45,7 @@ def _build_fake_clickzetta_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clickzetta_module(monkeypatch):
|
||||
def clickzetta_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
|
||||
import dify_vdb_clickzetta.clickzetta_vector as module
|
||||
|
||||
@ -218,7 +218,7 @@ def test_search_by_like_returns_documents_with_default_score(clickzetta_module):
|
||||
assert docs[0].metadata["score"] == 0.5
|
||||
|
||||
|
||||
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
|
||||
def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = clickzetta_module.ClickzettaVectorFactory()
|
||||
dataset = SimpleNamespace(id="dataset-1")
|
||||
|
||||
@ -243,7 +243,7 @@ def test_factory_initializes_clickzetta_vector(clickzetta_module, monkeypatch):
|
||||
assert vector_cls.call_args.kwargs["collection_name"] == "collection"
|
||||
|
||||
|
||||
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
clickzetta_module.ClickzettaConnectionPool._instance = None
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
|
||||
@ -255,7 +255,7 @@ def test_connection_pool_singleton_and_config_key(clickzetta_module, monkeypatch
|
||||
assert "username:instance:service:workspace:cluster:dify" in key
|
||||
|
||||
|
||||
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_create_connection_retries_and_configures(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -274,7 +274,7 @@ def test_connection_pool_create_connection_retries_and_configures(clickzetta_mod
|
||||
pool._configure_connection.assert_called_once_with(connection)
|
||||
|
||||
|
||||
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_create_connection_raises_after_retries(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -318,7 +318,7 @@ def test_connection_pool_configure_connection_swallows_errors(clickzetta_module)
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "_start_cleanup_thread", MagicMock())
|
||||
pool = clickzetta_module.ClickzettaConnectionPool()
|
||||
config = _config(clickzetta_module)
|
||||
@ -360,7 +360,7 @@ def test_connection_pool_get_return_cleanup_and_shutdown(clickzetta_module, monk
|
||||
assert pool._shutdown is True
|
||||
|
||||
|
||||
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch):
|
||||
def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
|
||||
pool._shutdown = False
|
||||
pool._cleanup_expired_connections = MagicMock(side_effect=lambda: setattr(pool, "_shutdown", True))
|
||||
@ -384,7 +384,7 @@ def test_connection_pool_start_cleanup_thread_runs_worker_once(clickzetta_module
|
||||
pool._cleanup_expired_connections.assert_called_once()
|
||||
|
||||
|
||||
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch):
|
||||
def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
pool.get_connection.return_value = "conn"
|
||||
monkeypatch.setattr(clickzetta_module.ClickzettaConnectionPool, "get_instance", MagicMock(return_value=pool))
|
||||
@ -405,7 +405,7 @@ def test_vector_init_connection_context_and_helpers(clickzetta_module, monkeypat
|
||||
assert vector._ensure_connection() == "conn"
|
||||
|
||||
|
||||
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch):
|
||||
def test_write_queue_initialization_worker_and_execute_write(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Thread:
|
||||
def __init__(self, target, daemon):
|
||||
self.target = target
|
||||
@ -579,7 +579,7 @@ def test_create_inverted_index_branches(clickzetta_module):
|
||||
vector._create_inverted_index(cursor)
|
||||
|
||||
|
||||
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch):
|
||||
def test_add_texts_batches_and_insert_batch_behaviors(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = clickzetta_module.ClickzettaVector.__new__(clickzetta_module.ClickzettaVector)
|
||||
vector._config = _config(clickzetta_module)
|
||||
vector._config.batch_size = 2
|
||||
@ -811,7 +811,7 @@ def test_clickzetta_pool_cleanup_and_shutdown_edge_paths(clickzetta_module):
|
||||
assert pool._shutdown is True
|
||||
|
||||
|
||||
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch):
|
||||
def test_clickzetta_pool_cleanup_thread_and_worker_exception_paths(clickzetta_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = clickzetta_module.ClickzettaConnectionPool.__new__(clickzetta_module.ClickzettaConnectionPool)
|
||||
pool._shutdown = False
|
||||
|
||||
|
||||
@ -150,7 +150,7 @@ def _build_fake_couchbase_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def couchbase_module(monkeypatch):
|
||||
def couchbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_couchbase_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -194,7 +194,7 @@ def test_init_sets_cluster_handles(couchbase_module):
|
||||
vector._cluster.wait_until_ready.assert_called_once()
|
||||
|
||||
|
||||
def test_create_and_create_collection_branches(couchbase_module, monkeypatch):
|
||||
def test_create_and_create_collection_branches(couchbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = couchbase_module.CouchbaseVector.__new__(couchbase_module.CouchbaseVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._client_config = _config(couchbase_module)
|
||||
@ -319,7 +319,7 @@ def test_search_methods_and_format_metadata(couchbase_module):
|
||||
assert vector._format_metadata({"metadata.a": 1, "plain": 2}) == {"a": 1, "plain": 2}
|
||||
|
||||
|
||||
def test_delete_collection_and_factory(couchbase_module, monkeypatch):
|
||||
def test_delete_collection_and_factory(couchbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = couchbase_module.CouchbaseVector("collection_1", _config(couchbase_module))
|
||||
scopes = [
|
||||
SimpleNamespace(collections=[SimpleNamespace(name="other")]),
|
||||
|
||||
@ -28,7 +28,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def elasticsearch_ja_module(monkeypatch):
|
||||
def elasticsearch_ja_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -39,7 +39,7 @@ def elasticsearch_ja_module(monkeypatch):
|
||||
return importlib.reload(ja_module)
|
||||
|
||||
|
||||
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
|
||||
def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -57,7 +57,7 @@ def test_create_collection_cache_hit(elasticsearch_ja_module, monkeypatch):
|
||||
elasticsearch_ja_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch):
|
||||
def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -87,7 +87,7 @@ def test_create_collection_create_and_exists_paths(elasticsearch_ja_module, monk
|
||||
elasticsearch_ja_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch):
|
||||
def test_ja_factory_uses_existing_or_generated_collection(elasticsearch_ja_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = elasticsearch_ja_module.ElasticSearchJaVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def elasticsearch_module(monkeypatch):
|
||||
def elasticsearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -287,7 +287,7 @@ def test_search_by_vector_and_full_text(elasticsearch_module):
|
||||
assert "bool" in query
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
|
||||
def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -331,7 +331,7 @@ def test_create_and_create_collection_paths(elasticsearch_module, monkeypatch):
|
||||
elasticsearch_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch):
|
||||
def test_elasticsearch_factory_branches(elasticsearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = elasticsearch_module.ElasticSearchVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_hologres_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hologres_module(monkeypatch):
|
||||
def hologres_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_hologres_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -266,7 +266,7 @@ def test_delete_handles_existing_and_missing_tables(hologres_module):
|
||||
vector._client.drop_table.assert_called_once_with(vector.table_name)
|
||||
|
||||
|
||||
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch):
|
||||
def test_create_collection_returns_early_when_cache_hits(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -281,7 +281,7 @@ def test_create_collection_returns_early_when_cache_hits(hologres_module, monkey
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch):
|
||||
def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -313,7 +313,7 @@ def test_create_collection_creates_table_and_indexes(hologres_module, monkeypatc
|
||||
hologres_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch):
|
||||
def test_create_collection_raises_when_table_never_becomes_ready(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = False
|
||||
@ -331,7 +331,7 @@ def test_create_collection_raises_when_table_never_becomes_ready(hologres_module
|
||||
hologres_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch):
|
||||
def test_hologres_factory_uses_existing_or_generated_collection(hologres_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = hologres_module.HologresVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -29,7 +29,7 @@ def _build_fake_elasticsearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def huawei_module(monkeypatch):
|
||||
def huawei_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -155,7 +155,7 @@ def test_search_by_vector_and_full_text(huawei_module):
|
||||
assert docs[0].page_content == "text-hit"
|
||||
|
||||
|
||||
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch):
|
||||
def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class FakeDocument:
|
||||
def __init__(self, page_content, vector, metadata):
|
||||
self.page_content = page_content
|
||||
@ -185,7 +185,7 @@ def test_search_by_vector_skips_hits_without_metadata(huawei_module, monkeypatch
|
||||
assert docs == []
|
||||
|
||||
|
||||
def test_create_and_create_collection_paths(huawei_module, monkeypatch):
|
||||
def test_create_and_create_collection_paths(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -218,7 +218,7 @@ def test_create_and_create_collection_paths(huawei_module, monkeypatch):
|
||||
huawei_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_huawei_factory_branches(huawei_module, monkeypatch):
|
||||
def test_huawei_factory_branches(huawei_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = huawei_module.HuaweiCloudVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -23,7 +23,7 @@ def _build_fake_iris_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def iris_module(monkeypatch):
|
||||
def iris_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
|
||||
|
||||
import dify_vdb_iris.iris_vector as module
|
||||
@ -249,7 +249,7 @@ def test_iris_vector_init_get_cursor_and_create(iris_module):
|
||||
vector._create_collection.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
|
||||
def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", _config(iris_module))
|
||||
|
||||
@ -297,7 +297,7 @@ def test_iris_vector_crud_and_vector_search(iris_module, monkeypatch):
|
||||
assert docs[0].metadata["score"] == pytest.approx(0.9)
|
||||
|
||||
|
||||
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
|
||||
def test_iris_vector_full_text_search_paths(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
cfg = _config(iris_module, IRIS_TEXT_INDEX=True)
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", cfg)
|
||||
@ -344,7 +344,7 @@ def test_iris_vector_full_text_search_paths(iris_module, monkeypatch):
|
||||
assert vector_like.search_by_full_text("100%", top_k=1) == []
|
||||
|
||||
|
||||
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch):
|
||||
def test_iris_vector_delete_create_collection_and_factory(iris_module, monkeypatch: pytest.MonkeyPatch):
|
||||
with patch.object(iris_module, "get_iris_pool", return_value=MagicMock()):
|
||||
vector = iris_module.IrisVector("collection", _config(iris_module, IRIS_TEXT_INDEX=True))
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lindorm_module(monkeypatch):
|
||||
def lindorm_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -100,7 +100,7 @@ def test_to_opensearch_params_and_init(lindorm_module):
|
||||
assert vector_ugc._routing == "route"
|
||||
|
||||
|
||||
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch):
|
||||
def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = lindorm_module.LindormVectorStore(
|
||||
"collection", _config(lindorm_module), using_ugc=True, routing_value="route"
|
||||
)
|
||||
@ -301,7 +301,7 @@ def test_search_by_full_text_success_and_error(lindorm_module):
|
||||
vector.search_by_full_text("hello")
|
||||
|
||||
|
||||
def test_create_collection_paths(lindorm_module, monkeypatch):
|
||||
def test_create_collection_paths(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = lindorm_module.LindormVectorStore("collection", _config(lindorm_module), using_ugc=False)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
@ -331,7 +331,7 @@ def test_create_collection_paths(lindorm_module, monkeypatch):
|
||||
vector._client.indices.create.assert_not_called()
|
||||
|
||||
|
||||
def test_lindorm_factory_branches(lindorm_module, monkeypatch):
|
||||
def test_lindorm_factory_branches(lindorm_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = lindorm_module.LindormVectorStoreFactory()
|
||||
|
||||
monkeypatch.setattr(lindorm_module.dify_config, "LINDORM_URL", "http://localhost:9200")
|
||||
|
||||
@ -32,7 +32,7 @@ def _build_fake_mo_vector_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def matrixone_module(monkeypatch):
|
||||
def matrixone_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_mo_vector_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -70,7 +70,7 @@ def test_matrixone_config_validation(matrixone_module, field, value, message):
|
||||
matrixone_module.MatrixoneConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch):
|
||||
def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -86,7 +86,7 @@ def test_get_client_creates_full_text_index_when_cache_misses(matrixone_module,
|
||||
matrixone_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch):
|
||||
def test_get_client_skips_index_creation_when_cache_hits(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -146,7 +146,7 @@ def test_get_type_and_create_delegate_to_add_texts(matrixone_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch):
|
||||
def test_get_client_handles_full_text_index_creation_error(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -165,7 +165,7 @@ def test_get_client_handles_full_text_index_creation_error(matrixone_module, mon
|
||||
matrixone_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch):
|
||||
def test_add_texts_generates_ids_and_inserts(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = matrixone_module.MatrixoneVector("collection_1", _valid_config(matrixone_module))
|
||||
vector.client = MagicMock()
|
||||
monkeypatch.setattr(matrixone_module.uuid, "uuid4", lambda: "generated-uuid")
|
||||
@ -224,7 +224,7 @@ def test_search_by_vector_builds_documents(matrixone_module):
|
||||
assert vector.client.query.call_args.kwargs["filter"] == {"document_id": {"$in": ["d-1"]}}
|
||||
|
||||
|
||||
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch):
|
||||
def test_matrixone_factory_uses_existing_or_generated_collection(matrixone_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = matrixone_module.MatrixoneVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -99,7 +99,7 @@ def _build_fake_pymilvus_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def milvus_module(monkeypatch):
|
||||
def milvus_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pymilvus_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -327,7 +327,7 @@ def test_process_search_results_and_search_methods(milvus_module):
|
||||
assert "document_id" in vector._client.search.call_args.kwargs["filter"]
|
||||
|
||||
|
||||
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch):
|
||||
def test_create_collection_cache_and_existing_collection(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -351,7 +351,7 @@ def test_create_collection_cache_and_existing_collection(milvus_module, monkeypa
|
||||
milvus_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch):
|
||||
def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -385,7 +385,7 @@ def test_create_collection_builds_schema_and_indexes(milvus_module, monkeypatch)
|
||||
assert call_kwargs["consistency_level"] == "Session"
|
||||
|
||||
|
||||
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch):
|
||||
def test_factory_initializes_milvus_vector(milvus_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = milvus_module.MilvusVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -38,7 +38,7 @@ def _build_fake_clickhouse_connect_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def myscale_module(monkeypatch):
|
||||
def myscale_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_module = _build_fake_clickhouse_connect_module()
|
||||
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
|
||||
|
||||
@ -90,7 +90,7 @@ def test_delete_by_ids_short_circuits_on_empty_list(myscale_module):
|
||||
vector._client.command.assert_not_called()
|
||||
|
||||
|
||||
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch):
|
||||
def test_factory_initializes_lower_case_collection_name(myscale_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = myscale_module.MyScaleVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -160,7 +160,7 @@ def test_create_collection_builds_expected_sql(myscale_module):
|
||||
assert "INDEX text_idx text TYPE fts('tokenizer=unicode')" in sql
|
||||
|
||||
|
||||
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
|
||||
def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = myscale_module.MyScaleVector("collection_1", _config(myscale_module))
|
||||
monkeypatch.setattr(myscale_module.uuid, "uuid4", lambda: "generated-uuid")
|
||||
docs = [
|
||||
|
||||
@ -53,7 +53,7 @@ def _build_fake_pyobvector_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oceanbase_module(monkeypatch):
|
||||
def oceanbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module())
|
||||
|
||||
import dify_vdb_oceanbase.oceanbase_vector as module
|
||||
@ -208,7 +208,7 @@ def test_create_delegates_to_collection_and_insert(oceanbase_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -234,7 +234,7 @@ def test_create_collection_cache_and_existing_table_short_circuits(oceanbase_mod
|
||||
vector.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -271,7 +271,7 @@ def test_create_collection_happy_path_with_hybrid_and_index(oceanbase_module, mo
|
||||
oceanbase_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_error_paths(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_error_paths(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -308,7 +308,7 @@ def test_create_collection_error_paths(oceanbase_module, monkeypatch):
|
||||
vector._create_collection()
|
||||
|
||||
|
||||
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch):
|
||||
def test_create_collection_fulltext_and_metadata_index_exceptions(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -517,7 +517,7 @@ def test_delete_success_and_exception(oceanbase_module):
|
||||
vector.delete()
|
||||
|
||||
|
||||
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch):
|
||||
def test_oceanbase_factory_uses_existing_or_generated_collection(oceanbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = oceanbase_module.OceanBaseVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opengauss_module(monkeypatch):
|
||||
def opengauss_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_psycopg2_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -88,7 +88,7 @@ def test_opengauss_config_validation_rejects_min_greater_than_max(opengauss_modu
|
||||
opengauss_module.OpenGaussConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
|
||||
def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -99,7 +99,7 @@ def test_init_sets_table_name_and_vector_type(opengauss_module, monkeypatch):
|
||||
assert vector.pool is pool
|
||||
|
||||
|
||||
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
|
||||
def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -126,7 +126,7 @@ def test_create_index_with_pq_executes_pq_sql(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch):
|
||||
def test_create_index_skips_index_sql_for_large_dimension(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -158,7 +158,7 @@ def test_search_by_vector_validates_top_k(opengauss_module):
|
||||
vector.search_by_vector([0.1, 0.2], top_k=0)
|
||||
|
||||
|
||||
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch):
|
||||
def test_delete_by_ids_short_circuits_with_empty_input(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
|
||||
@ -200,7 +200,7 @@ def test_create_calls_collection_insert_and_index(opengauss_module):
|
||||
vector._create_index.assert_called_once_with(2)
|
||||
|
||||
|
||||
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
|
||||
def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -220,7 +220,7 @@ def test_create_index_returns_early_on_cache_hit(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch):
|
||||
def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -245,7 +245,7 @@ def test_create_index_without_pq_executes_standard_index_sql(opengauss_module, m
|
||||
assert any("embedding_cosine_embedding_collection_1_idx" in query for query in sql)
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch):
|
||||
def test_add_texts_uses_execute_values(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
vector = opengauss_module.OpenGauss("collection_1", _config(opengauss_module))
|
||||
@ -342,7 +342,7 @@ def test_search_by_full_text_validates_top_k(opengauss_module):
|
||||
vector.search_by_full_text("query", top_k=0)
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
|
||||
def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(opengauss_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
lock = MagicMock()
|
||||
@ -370,7 +370,7 @@ def test_create_collection_cache_and_create_path(opengauss_module, monkeypatch):
|
||||
opengauss_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch):
|
||||
def test_opengauss_factory_uses_existing_or_generated_collection(opengauss_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = opengauss_module.OpenGaussFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -59,7 +59,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opensearch_module(monkeypatch):
|
||||
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -95,7 +95,7 @@ class TestOpenSearchConfig:
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] == ("admin", "password")
|
||||
|
||||
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch):
|
||||
def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Session:
|
||||
def get_credentials(self):
|
||||
return "creds"
|
||||
|
||||
@ -58,7 +58,7 @@ def _build_fake_opensearch_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def opensearch_module(monkeypatch):
|
||||
def opensearch_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -116,7 +116,7 @@ def test_config_validation_for_aws_auth_and_https_fields(opensearch_module):
|
||||
opensearch_module.OpenSearchConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch):
|
||||
def test_create_aws_managed_iam_auth(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _Session:
|
||||
def get_credentials(self):
|
||||
return "creds"
|
||||
@ -167,7 +167,7 @@ def test_init_and_create_delegate_calls(opensearch_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch):
|
||||
def test_add_texts_supports_regular_and_aoss_clients(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = opensearch_module.OpenSearchVector("Collection_1", _config(opensearch_module, aws_service="es"))
|
||||
docs = [
|
||||
Document(page_content="a", metadata={"doc_id": "1"}),
|
||||
@ -308,7 +308,7 @@ def test_search_by_full_text_and_filters(opensearch_module):
|
||||
assert query["query"]["bool"]["filter"] == [{"terms": {"metadata.document_id": ["d-1"]}}]
|
||||
|
||||
|
||||
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch):
|
||||
def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -331,7 +331,7 @@ def test_create_collection_cache_and_create_path(opensearch_module, monkeypatch)
|
||||
opensearch_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch):
|
||||
def test_opensearch_factory_initializes_expected_collection_name(opensearch_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = opensearch_module.OpenSearchVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -51,7 +51,7 @@ def _connection_with_cursor(cursor):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oracle_module(monkeypatch):
|
||||
def oracle_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_oracle_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -94,7 +94,7 @@ def test_oracle_config_validation_autonomous_requirements(oracle_module):
|
||||
)
|
||||
|
||||
|
||||
def test_init_and_get_type(oracle_module, monkeypatch):
|
||||
def test_init_and_get_type(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(oracle_module.oracledb, "create_pool", MagicMock(return_value=pool))
|
||||
vector = oracle_module.OracleVector("collection_1", _config(oracle_module))
|
||||
@ -139,7 +139,7 @@ def test_numpy_converters_and_type_handlers(oracle_module):
|
||||
assert out_float64.dtype == numpy.float64
|
||||
|
||||
|
||||
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch):
|
||||
def test_get_connection_supports_standard_and_autonomous_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
connect = MagicMock(return_value="connection")
|
||||
monkeypatch.setattr(oracle_module.oracledb, "connect", connect)
|
||||
|
||||
@ -173,7 +173,7 @@ def test_create_delegates_collection_and_insert(oracle_module):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch):
|
||||
def test_add_texts_inserts_and_logs_on_failures(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector.input_type_handler = MagicMock()
|
||||
@ -279,7 +279,7 @@ def _fake_nltk_module(*, missing_data=False):
|
||||
return nltk, nltk_corpus
|
||||
|
||||
|
||||
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch):
|
||||
def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
|
||||
@ -305,7 +305,7 @@ def test_search_by_full_text_chinese_and_english_paths(oracle_module, monkeypatc
|
||||
assert "doc_id_0" in en_params
|
||||
|
||||
|
||||
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch):
|
||||
def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = oracle_module.OracleVector.__new__(oracle_module.OracleVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector._get_connection = MagicMock()
|
||||
@ -320,7 +320,7 @@ def test_search_by_full_text_empty_query_and_missing_nltk(oracle_module, monkeyp
|
||||
vector.search_by_full_text("english query")
|
||||
|
||||
|
||||
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
|
||||
def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -346,7 +346,9 @@ def test_create_collection_cache_and_execute_path(oracle_module, monkeypatch):
|
||||
oracle_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(oracle_module, monkeypatch):
|
||||
def test_oracle_factory_init_vector_uses_existing_or_generated_collection(
|
||||
oracle_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
factory = oracle_module.OracleVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -79,7 +79,7 @@ def _patch_both(monkeypatch, module, calls, execute_results=None):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pgvecto_module(monkeypatch):
|
||||
def pgvecto_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_pgvecto_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -126,7 +126,7 @@ def test_collection_base_has_expected_annotations(pgvecto_module):
|
||||
assert {"id", "text", "meta", "vector"} <= set(annotations)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -145,7 +145,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
session_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -169,7 +169,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch):
|
||||
module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
runtime_calls = []
|
||||
@ -241,7 +241,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch):
|
||||
assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls)
|
||||
|
||||
|
||||
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
init_calls = []
|
||||
monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine"))
|
||||
@ -313,7 +313,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch):
|
||||
assert vector.search_by_full_text("hello") == []
|
||||
|
||||
|
||||
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch):
|
||||
def test_factory_uses_existing_or_generated_collection(pgvecto_module, monkeypatch: pytest.MonkeyPatch):
|
||||
module, _ = pgvecto_module
|
||||
factory = module.PGVectoRSFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
|
||||
@ -336,7 +336,7 @@ def test_create_delegates_collection_creation_and_insert():
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch):
|
||||
def test_add_texts_uses_execute_values_and_returns_ids(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = PGVector.__new__(PGVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
|
||||
@ -387,7 +387,7 @@ def test_text_get_and_delete_methods():
|
||||
assert any("DROP TABLE IF EXISTS embedding_collection_1" in sql for sql in executed_sql)
|
||||
|
||||
|
||||
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch):
|
||||
def test_delete_by_ids_handles_empty_undefined_table_and_generic_exception(monkeypatch: pytest.MonkeyPatch):
|
||||
vector = PGVector.__new__(PGVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
cursor = MagicMock()
|
||||
@ -464,7 +464,7 @@ def test_search_by_full_text_branches_for_bigm_and_standard():
|
||||
assert "bigm_similarity" in cursor.execute.call_args_list[1].args[0]
|
||||
|
||||
|
||||
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch):
|
||||
def test_pgvector_factory_initializes_expected_collection_name(monkeypatch: pytest.MonkeyPatch):
|
||||
factory = pgvector_module.PGVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -121,7 +121,7 @@ def _build_fake_qdrant_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_module(monkeypatch):
|
||||
def qdrant_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_qdrant_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -170,7 +170,7 @@ def test_init_and_basic_behaviour(qdrant_module):
|
||||
vector.add_texts.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_add_texts(qdrant_module, monkeypatch):
|
||||
def test_create_collection_and_add_texts(qdrant_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = qdrant_module.QdrantVector("collection_1", "group-1", _config(qdrant_module))
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -288,7 +288,7 @@ def test_search_and_helper_methods(qdrant_module):
|
||||
assert doc.page_content == "doc"
|
||||
|
||||
|
||||
def test_qdrant_factory_paths(qdrant_module, monkeypatch):
|
||||
def test_qdrant_factory_paths(qdrant_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = qdrant_module.QdrantVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -59,7 +59,7 @@ def _patch_both(monkeypatch, module, session):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def relyt_module(monkeypatch):
|
||||
def relyt_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_relyt_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -97,7 +97,7 @@ def test_relyt_config_validation(relyt_module, field, value, message):
|
||||
relyt_module.RelytConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
|
||||
def test_init_get_type_and_create_delegate(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
engine = MagicMock()
|
||||
monkeypatch.setattr(relyt_module, "create_engine", MagicMock(return_value=engine))
|
||||
vector = relyt_module.RelytVector("collection_1", _config(relyt_module), group_id="group-1")
|
||||
@ -114,7 +114,7 @@ def test_init_get_type_and_create_delegate(relyt_module, monkeypatch):
|
||||
vector.add_texts.assert_called_once_with(docs, [[0.1, 0.2]])
|
||||
|
||||
|
||||
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -142,7 +142,7 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch):
|
||||
relyt_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch):
|
||||
def test_add_texts_and_metadata_queries(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._group_id = "group-1"
|
||||
@ -212,7 +212,7 @@ def test_delete_by_metadata_field_calls_delete_by_uuids(relyt_module):
|
||||
|
||||
|
||||
# 3. delete_by_ids translates to uuids
|
||||
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
|
||||
def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -225,7 +225,7 @@ def test_delete_by_ids_translates_to_uuids(relyt_module, monkeypatch):
|
||||
|
||||
|
||||
# 4. text_exists True
|
||||
def test_text_exists_true(relyt_module, monkeypatch):
|
||||
def test_text_exists_true(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -236,7 +236,7 @@ def test_text_exists_true(relyt_module, monkeypatch):
|
||||
|
||||
|
||||
# 5. text_exists False
|
||||
def test_text_exists_false(relyt_module, monkeypatch):
|
||||
def test_text_exists_false(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -284,7 +284,7 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module):
|
||||
|
||||
|
||||
# 8. delete commits session
|
||||
def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
def test_delete_drops_table(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector.client = MagicMock()
|
||||
@ -295,7 +295,7 @@ def test_delete_drops_table(relyt_module, monkeypatch):
|
||||
session.execute.assert_called_once()
|
||||
|
||||
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch):
|
||||
def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = relyt_module.RelytVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -77,7 +77,7 @@ def _build_fake_tablestore_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tablestore_module(monkeypatch):
|
||||
def tablestore_module(monkeypatch: pytest.MonkeyPatch):
|
||||
fake_module = _build_fake_tablestore_module()
|
||||
monkeypatch.setitem(sys.modules, "tablestore", fake_module)
|
||||
|
||||
@ -177,7 +177,7 @@ def test_get_by_ids_text_exists_delete_and_wrappers(tablestore_module):
|
||||
vector._delete_table_if_exist.assert_called_once()
|
||||
|
||||
|
||||
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch):
|
||||
def test_create_collection_and_table_index_lifecycle(tablestore_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tablestore_module.TableStoreVector("collection_1", _config(tablestore_module))
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
@ -289,7 +289,7 @@ def test_write_row_and_search_helpers(tablestore_module):
|
||||
assert "score" not in docs[0].metadata
|
||||
|
||||
|
||||
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch):
|
||||
def test_tablestore_factory_uses_existing_or_generated_collection(tablestore_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tablestore_module.TableStoreVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -136,7 +136,7 @@ def _build_fake_tencent_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tencent_module(monkeypatch):
|
||||
def tencent_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_tencent_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -187,7 +187,7 @@ def test_config_and_init_paths(tencent_module):
|
||||
assert vector._enable_hybrid_search is False
|
||||
|
||||
|
||||
def test_create_collection_branches(tencent_module, monkeypatch):
|
||||
def test_create_collection_branches(tencent_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tencent_module.TencentVector("collection_1", _config(tencent_module))
|
||||
|
||||
lock = MagicMock()
|
||||
@ -279,7 +279,7 @@ def test_create_add_delete_and_search_behaviour(tencent_module):
|
||||
vector._client.drop_collection.assert_called_once()
|
||||
|
||||
|
||||
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch):
|
||||
def test_tencent_factory_existing_and_generated_collection(tencent_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tencent_module.TencentVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -46,7 +46,7 @@ def test_tidb_config_validation(tidb_module, field, value, message):
|
||||
tidb_module.TiDBVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
|
||||
def test_init_get_type_and_distance_func(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(tidb_module, "create_engine", MagicMock(return_value="engine"))
|
||||
|
||||
vector = tidb_module.TiDBVector("collection_1", _config(tidb_module), distance_func="L2")
|
||||
@ -63,7 +63,7 @@ def test_init_get_type_and_distance_func(tidb_module, monkeypatch):
|
||||
assert vector._get_distance_func() == "VEC_COSINE_DISTANCE"
|
||||
|
||||
|
||||
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch):
|
||||
def test_table_builds_columns_with_tidb_vector_type(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
fake_tidb_vector = types.ModuleType("tidb_vector")
|
||||
fake_tidb_sqlalchemy = types.ModuleType("tidb_vector.sqlalchemy")
|
||||
|
||||
@ -107,7 +107,7 @@ def test_create_calls_collection_and_add_texts(tidb_module):
|
||||
assert vector._dimension == 2
|
||||
|
||||
|
||||
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
|
||||
def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -127,7 +127,7 @@ def test_create_collection_skips_when_cache_hit(tidb_module, monkeypatch):
|
||||
tidb_module.redis_client.set.assert_not_called()
|
||||
|
||||
|
||||
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch):
|
||||
def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -160,7 +160,7 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke
|
||||
tidb_module.redis_client.set.assert_called_once()
|
||||
|
||||
|
||||
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
|
||||
def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
class _InsertStmt:
|
||||
def __init__(self, table):
|
||||
self.table = table
|
||||
@ -198,7 +198,7 @@ def test_add_texts_batches_inserts_and_returns_ids(tidb_module, monkeypatch):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tidb_vector_with_session(tidb_module, monkeypatch):
|
||||
def tidb_vector_with_session(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector)
|
||||
vector._collection_name = "collection_1"
|
||||
vector._engine = MagicMock()
|
||||
@ -354,7 +354,7 @@ def test_delete_by_metadata_field_does_nothing_when_no_ids(tidb_module):
|
||||
|
||||
|
||||
# Test search_by_vector filters and scores
|
||||
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = [
|
||||
('{"doc_id":"id-1","document_id":"d-1"}', "text-1", 0.2),
|
||||
@ -392,7 +392,7 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch):
|
||||
|
||||
|
||||
# Test delete drops table
|
||||
def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
def test_delete_drops_table(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
session = MagicMock()
|
||||
session.execute.return_value = None
|
||||
|
||||
@ -413,7 +413,7 @@ def test_delete_drops_table(tidb_module, monkeypatch):
|
||||
assert "DROP TABLE IF EXISTS collection_1" in drop_sql
|
||||
|
||||
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch):
|
||||
def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = tidb_module.TiDBVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -36,7 +36,7 @@ def _build_fake_upstash_module():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upstash_module(monkeypatch):
|
||||
def upstash_module(monkeypatch: pytest.MonkeyPatch):
|
||||
# Remove patched modules if present
|
||||
for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]:
|
||||
if modname in sys.modules:
|
||||
@ -65,7 +65,7 @@ def test_upstash_config_validation(upstash_module, field, value, message):
|
||||
upstash_module.UpstashVectorConfig.model_validate(values)
|
||||
|
||||
|
||||
def test_init_get_type_and_dimension(upstash_module, monkeypatch):
|
||||
def test_init_get_type_and_dimension(upstash_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = upstash_module.UpstashVector("collection_1", _config(upstash_module))
|
||||
|
||||
assert vector.get_type() == upstash_module.VectorType.UPSTASH
|
||||
@ -162,7 +162,7 @@ def test_search_by_vector_filter_threshold_and_delete(upstash_module):
|
||||
vector.index.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch):
|
||||
def test_upstash_factory_uses_existing_or_generated_collection(upstash_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = upstash_module.UpstashVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -37,7 +37,7 @@ def _build_fake_psycopg2_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vastbase_module(monkeypatch):
|
||||
def vastbase_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_psycopg2_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -93,7 +93,7 @@ def test_vastbase_config_rejects_invalid_connection_window(vastbase_module):
|
||||
)
|
||||
|
||||
|
||||
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
|
||||
def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
pool = MagicMock()
|
||||
monkeypatch.setattr(vastbase_module.psycopg2.pool, "SimpleConnectionPool", MagicMock(return_value=pool))
|
||||
|
||||
@ -114,7 +114,7 @@ def test_init_and_get_cursor_context_manager(vastbase_module, monkeypatch):
|
||||
pool.putconn.assert_called_once_with(conn)
|
||||
|
||||
|
||||
def test_create_and_add_texts(vastbase_module, monkeypatch):
|
||||
def test_create_and_add_texts(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
vector = vastbase_module.VastbaseVector.__new__(vastbase_module.VastbaseVector)
|
||||
vector.table_name = "embedding_collection_1"
|
||||
vector._create_collection = MagicMock()
|
||||
@ -205,7 +205,7 @@ def test_search_by_vector_and_full_text(vastbase_module):
|
||||
assert full_docs[0].page_content == "full-text"
|
||||
|
||||
|
||||
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch):
|
||||
def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -240,7 +240,7 @@ def test_create_collection_cache_and_dimension_branches(vastbase_module, monkeyp
|
||||
vastbase_module.redis_client.set.assert_called()
|
||||
|
||||
|
||||
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch):
|
||||
def test_vastbase_factory_uses_existing_or_generated_collection(vastbase_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = vastbase_module.VastbaseVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
|
||||
@ -79,7 +79,7 @@ def _build_fake_vikingdb_modules():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vikingdb_module(monkeypatch):
|
||||
def vikingdb_module(monkeypatch: pytest.MonkeyPatch):
|
||||
for name, module in _build_fake_vikingdb_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
@ -117,7 +117,7 @@ def test_init_get_type_and_has_checks(vikingdb_module):
|
||||
assert vector._has_index() is False
|
||||
|
||||
|
||||
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch):
|
||||
def test_create_collection_cache_and_creation_paths(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
lock = MagicMock()
|
||||
lock.__enter__.return_value = None
|
||||
lock.__exit__.return_value = None
|
||||
@ -253,7 +253,7 @@ def test_delete_drops_index_and_collection_when_present(vikingdb_module):
|
||||
vector._client.drop_collection.assert_not_called()
|
||||
|
||||
|
||||
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch):
|
||||
def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, monkeypatch: pytest.MonkeyPatch):
|
||||
factory = vikingdb_module.VikingDBVectorFactory()
|
||||
dataset_with_index = SimpleNamespace(
|
||||
id="dataset-1",
|
||||
@ -293,7 +293,9 @@ def test_vikingdb_factory_validates_config_and_builds_vector(vikingdb_module, mo
|
||||
("VIKINGDB_SCHEME", "VIKINGDB_SCHEME should not be None"),
|
||||
],
|
||||
)
|
||||
def test_vikingdb_factory_raises_when_required_config_missing(vikingdb_module, monkeypatch, field, message):
|
||||
def test_vikingdb_factory_raises_when_required_config_missing(
|
||||
vikingdb_module, monkeypatch: pytest.MonkeyPatch, field, message
|
||||
):
|
||||
factory = vikingdb_module.VikingDBVectorFactory()
|
||||
dataset = SimpleNamespace(
|
||||
id="dataset-1", index_struct_dict={"vector_store": {"class_prefix": "existing"}}, index_struct=None
|
||||
|
||||
@ -45,7 +45,7 @@ dependencies = [
|
||||
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]~=0.7.0",
|
||||
"graphon~=0.2.2",
|
||||
"graphon~=0.3.0",
|
||||
"httpx-sse~=0.4.0",
|
||||
"json-repair~=0.59.4",
|
||||
]
|
||||
@ -103,6 +103,7 @@ dify-trace-weave = { workspace = true }
|
||||
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
|
||||
package = false
|
||||
override-dependencies = [
|
||||
"litellm>=1.83.7",
|
||||
"pyarrow>=18.0.0",
|
||||
]
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@staticmethod
|
||||
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
|
||||
return session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
@ -59,31 +71,57 @@ class CreditPoolService:
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""check and deduct credits, returns actual credits deducted"""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
if pool.remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
|
||||
# deduct all remaining credits if less than required
|
||||
actual_credits = min(credits_required, pool.remaining_credits)
|
||||
"""Deduct exactly the requested credits or raise without mutating the pool."""
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
||||
)
|
||||
session.execute(stmt)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
remaining_credits = pool.remaining_credits
|
||||
if remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
if remaining_credits < credits_required:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
|
||||
pool.quota_used += credits_required
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
return actual_credits
|
||||
return credits_required
|
||||
|
||||
@classmethod
|
||||
def deduct_credits_capped(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""Deduct up to the available balance and return the actual deducted credits."""
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
|
||||
return 0
|
||||
|
||||
deducted_credits = min(credits_required, pool.remaining_credits)
|
||||
if deducted_credits <= 0:
|
||||
return 0
|
||||
|
||||
pool.quota_used += deducted_credits
|
||||
return deducted_credits
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
@ -107,15 +107,14 @@ class FileService:
|
||||
hash=hashlib.sha3_256(content).hexdigest(),
|
||||
source_url=source_url,
|
||||
)
|
||||
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
|
||||
# We can directly generate the `source_url` here before committing.
|
||||
if not upload_file.source_url:
|
||||
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
session.add(upload_file)
|
||||
session.commit()
|
||||
|
||||
if not upload_file.source_url:
|
||||
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
|
||||
49
api/services/recommend_app/category_order.py
Normal file
49
api/services/recommend_app/category_order.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Apply Redis-backed category ordering for DB-backed Explore apps."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Collection
|
||||
from typing import Any
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EXPLORE_APP_CATEGORY_ORDER_KEY_PREFIX = "explore:apps:category_order"
|
||||
|
||||
|
||||
def _category_order_key(language: str) -> str:
|
||||
return f"{EXPLORE_APP_CATEGORY_ORDER_KEY_PREFIX}:{language}"
|
||||
|
||||
|
||||
def get_explore_app_category_order(language: str) -> list[str]:
|
||||
try:
|
||||
raw_categories = redis_client.get(_category_order_key(language))
|
||||
except Exception:
|
||||
logger.exception("Failed to read explore app category order from Redis.")
|
||||
return []
|
||||
|
||||
if not raw_categories:
|
||||
return []
|
||||
|
||||
if isinstance(raw_categories, bytes):
|
||||
raw_categories = raw_categories.decode("utf-8")
|
||||
|
||||
try:
|
||||
categories: Any = json.loads(raw_categories)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
logger.warning("Invalid explore app category order payload for language %s.", language)
|
||||
return []
|
||||
|
||||
if not isinstance(categories, list):
|
||||
return []
|
||||
|
||||
return [category for category in categories if isinstance(category, str)]
|
||||
|
||||
|
||||
def order_categories(categories: Collection[str], language: str) -> list[str]:
|
||||
configured_order = get_explore_app_category_order(language)
|
||||
if configured_order:
|
||||
return configured_order
|
||||
|
||||
return sorted(categories)
|
||||
@ -6,6 +6,7 @@ from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, RecommendedApp
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.recommend_app.category_order import order_categories
|
||||
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
@ -18,7 +19,7 @@ class RecommendedAppItemDict(TypedDict):
|
||||
copyright: Any
|
||||
privacy_policy: Any
|
||||
custom_disclaimer: str
|
||||
category: str
|
||||
categories: list[str]
|
||||
position: int
|
||||
is_listed: bool
|
||||
|
||||
@ -80,6 +81,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not site:
|
||||
continue
|
||||
|
||||
app_categories = recommended_app.categories or []
|
||||
recommended_app_result: RecommendedAppItemDict = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
@ -88,15 +90,18 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"categories": app_categories,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
categories.update(app_categories)
|
||||
|
||||
return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
|
||||
return RecommendedAppsResultDict(
|
||||
recommended_apps=recommended_apps_result,
|
||||
categories=order_categories(categories, language),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
|
||||
@ -408,7 +408,7 @@ class BuiltinToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
|
||||
def set_default_provider(tenant_id: str, provider: str, id: str):
|
||||
"""
|
||||
set default provider
|
||||
"""
|
||||
@ -422,12 +422,11 @@ class BuiltinToolManageService:
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
# clear default provider
|
||||
# clear default provider (tenant-scoped: only one default per provider per workspace)
|
||||
session.execute(
|
||||
update(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.user_id == user_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
BuiltinToolProvider.is_default.is_(True),
|
||||
)
|
||||
|
||||
@ -194,14 +194,15 @@ class VariableTruncator(BaseTruncator):
|
||||
|
||||
result: _PartResult[Any]
|
||||
# Apply type-specific truncation with target size
|
||||
if isinstance(segment, ArraySegment):
|
||||
result = self._truncate_array(segment.value, target_size)
|
||||
elif isinstance(segment, StringSegment):
|
||||
result = self._truncate_string(segment.value, target_size)
|
||||
elif isinstance(segment, ObjectSegment):
|
||||
result = self._truncate_object(segment.value, target_size)
|
||||
else:
|
||||
raise AssertionError("this should be unreachable.")
|
||||
match segment:
|
||||
case ArraySegment():
|
||||
result = self._truncate_array(segment.value, target_size)
|
||||
case StringSegment():
|
||||
result = self._truncate_string(segment.value, target_size)
|
||||
case ObjectSegment():
|
||||
result = self._truncate_object(segment.value, target_size)
|
||||
case _:
|
||||
raise AssertionError("this should be unreachable.")
|
||||
|
||||
return _PartResult(
|
||||
value=segment.model_copy(update={"value": result.value}),
|
||||
@ -219,40 +220,41 @@ class VariableTruncator(BaseTruncator):
|
||||
return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
|
||||
if depth > _MAX_DEPTH:
|
||||
raise MaxDepthExceededError()
|
||||
if isinstance(value, str):
|
||||
# Ideally, the size of strings should be calculated based on their utf-8 encoded length.
|
||||
# However, this adds complexity as we would need to compute encoded sizes consistently
|
||||
# throughout the code. Therefore, we approximate the size using the string's length.
|
||||
# Rough estimate: number of characters, plus 2 for quotes
|
||||
return len(value) + 2
|
||||
elif isinstance(value, (int, float)):
|
||||
return len(str(value))
|
||||
elif isinstance(value, bool):
|
||||
return 4 if value else 5 # "true" or "false"
|
||||
elif value is None:
|
||||
return 4 # "null"
|
||||
elif isinstance(value, list):
|
||||
# Size = sum of elements + separators + brackets
|
||||
total = 2 # "[]"
|
||||
for i, item in enumerate(value):
|
||||
if i > 0:
|
||||
total += 1 # ","
|
||||
total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
|
||||
return total
|
||||
elif isinstance(value, dict):
|
||||
# Size = sum of keys + values + separators + brackets
|
||||
total = 2 # "{}"
|
||||
for index, key in enumerate(value.keys()):
|
||||
if index > 0:
|
||||
total += 1 # ","
|
||||
total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1) # Key as string
|
||||
total += 1 # ":"
|
||||
total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
|
||||
return total
|
||||
elif isinstance(value, File):
|
||||
return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
|
||||
else:
|
||||
raise UnknownTypeError(f"got unknown type {type(value)}")
|
||||
match value:
|
||||
case str():
|
||||
# Ideally, the size of strings should be calculated based on their utf-8 encoded length.
|
||||
# However, this adds complexity as we would need to compute encoded sizes consistently
|
||||
# throughout the code. Therefore, we approximate the size using the string's length.
|
||||
# Rough estimate: number of characters, plus 2 for quotes
|
||||
return len(value) + 2
|
||||
case bool():
|
||||
return 4 if value else 5 # "true" or "false"
|
||||
case int() | float():
|
||||
return len(str(value))
|
||||
case None:
|
||||
return 4 # "null"
|
||||
case list():
|
||||
# Size = sum of elements + separators + brackets
|
||||
total = 2 # "[]"
|
||||
for i, item in enumerate(value):
|
||||
if i > 0:
|
||||
total += 1 # ","
|
||||
total += VariableTruncator.calculate_json_size(item, depth=depth + 1)
|
||||
return total
|
||||
case dict():
|
||||
# Size = sum of keys + values + separators + brackets
|
||||
total = 2 # "{}"
|
||||
for index, key in enumerate(value.keys()):
|
||||
if index > 0:
|
||||
total += 1 # ","
|
||||
total += VariableTruncator.calculate_json_size(str(key), depth=depth + 1) # Key as string
|
||||
total += 1 # ":"
|
||||
total += VariableTruncator.calculate_json_size(value[key], depth=depth + 1)
|
||||
return total
|
||||
case File():
|
||||
return VariableTruncator.calculate_json_size(value.model_dump(), depth=depth + 1)
|
||||
case _:
|
||||
raise UnknownTypeError(f"got unknown type {type(value)}")
|
||||
|
||||
def _truncate_string(self, value: str, target_size: int) -> _PartResult[str]:
|
||||
if (size := self.calculate_json_size(value)) < target_size:
|
||||
@ -419,22 +421,23 @@ class VariableTruncator(BaseTruncator):
|
||||
target_size: int,
|
||||
) -> _PartResult[Any]:
|
||||
"""Truncate a value within an object to fit within budget."""
|
||||
if isinstance(val, UpdatedVariable):
|
||||
# TODO(Workflow): push UpdatedVariable normalization closer to its producer.
|
||||
return self._truncate_object(val.model_dump(), target_size)
|
||||
elif isinstance(val, str):
|
||||
return self._truncate_string(val, target_size)
|
||||
elif isinstance(val, list):
|
||||
return self._truncate_array(val, target_size)
|
||||
elif isinstance(val, dict):
|
||||
return self._truncate_object(val, target_size)
|
||||
elif isinstance(val, File):
|
||||
# File objects should not be truncated, return as-is
|
||||
return _PartResult(val, self.calculate_json_size(val), False)
|
||||
elif val is None or isinstance(val, (bool, int, float)):
|
||||
return _PartResult(val, self.calculate_json_size(val), False)
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
match val:
|
||||
case UpdatedVariable():
|
||||
# TODO(Workflow): push UpdatedVariable normalization closer to its producer.
|
||||
return self._truncate_object(val.model_dump(), target_size)
|
||||
case str():
|
||||
return self._truncate_string(val, target_size)
|
||||
case list():
|
||||
return self._truncate_array(val, target_size)
|
||||
case dict():
|
||||
return self._truncate_object(val, target_size)
|
||||
case File():
|
||||
# File objects should not be truncated, return as-is
|
||||
return _PartResult(val, self.calculate_json_size(val), False)
|
||||
case None | bool() | int() | float():
|
||||
return _PartResult(val, self.calculate_json_size(val), False)
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
||||
class DummyVariableTruncator(BaseTruncator):
|
||||
|
||||
@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader):
|
||||
# This approach reduces loading time by querying external systems concurrently.
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars)
|
||||
for selector, variable in offloaded_variables:
|
||||
variable_by_selector[selector] = variable
|
||||
for selector, offloaded_variable in offloaded_variables:
|
||||
variable_by_selector[selector] = offloaded_variable
|
||||
|
||||
return list(variable_by_selector.values())
|
||||
|
||||
|
||||
@ -1251,7 +1251,7 @@ class WorkflowService:
|
||||
node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
|
||||
node = HumanInputNode(
|
||||
node_id=node_config["id"],
|
||||
config=node_data,
|
||||
data=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import ConversationFromSource
|
||||
from models.enums import AppStatus, ConversationFromSource
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
@ -28,7 +28,7 @@ class TestChatMessageApiPermissions:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
@ -78,7 +78,7 @@ class TestChatMessageApiPermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
@ -130,7 +130,7 @@ class TestChatMessageApiPermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
|
||||
@ -14,7 +14,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.enums import AppStatus, FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, MessageFeedback
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@ -29,7 +29,7 @@ class TestFeedbackExportApi:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
app.name = "Test App"
|
||||
return app
|
||||
|
||||
@ -135,7 +135,7 @@ class TestFeedbackExportApi:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
@ -167,7 +167,13 @@ class TestFeedbackExportApi:
|
||||
mock_export_feedbacks.assert_called_once()
|
||||
|
||||
def test_feedback_export_csv_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
sample_feedback_data,
|
||||
):
|
||||
"""Test feedback export in CSV format."""
|
||||
|
||||
@ -202,7 +208,13 @@ class TestFeedbackExportApi:
|
||||
assert "text/csv" in response.content_type
|
||||
|
||||
def test_feedback_export_json_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account, sample_feedback_data
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
sample_feedback_data,
|
||||
):
|
||||
"""Test feedback export in JSON format."""
|
||||
|
||||
@ -246,7 +258,7 @@ class TestFeedbackExportApi:
|
||||
assert "application/json" in response.content_type
|
||||
|
||||
def test_feedback_export_with_filters(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with various filters."""
|
||||
|
||||
@ -287,7 +299,7 @@ class TestFeedbackExportApi:
|
||||
)
|
||||
|
||||
def test_feedback_export_invalid_date_format(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with invalid date format."""
|
||||
|
||||
@ -312,7 +324,7 @@ class TestFeedbackExportApi:
|
||||
assert "Parameter validation error" in response_json["error"]
|
||||
|
||||
def test_feedback_export_server_error(
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch, mock_app_model, mock_account
|
||||
self, test_client: FlaskClient, auth_header, monkeypatch: pytest.MonkeyPatch, mock_app_model, mock_account
|
||||
):
|
||||
"""Test feedback export with server error."""
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import AppStatus
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -25,7 +26,7 @@ class TestModelConfigResourcePermissions:
|
||||
app.id = str(uuid.uuid4())
|
||||
app.mode = AppMode.CHAT
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.status = "normal"
|
||||
app.status = AppStatus.NORMAL
|
||||
app.app_model_config_id = str(uuid.uuid4())
|
||||
return app
|
||||
|
||||
@ -73,7 +74,7 @@ class TestModelConfigResourcePermissions:
|
||||
self,
|
||||
test_client: FlaskClient,
|
||||
auth_header,
|
||||
monkeypatch,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
mock_app_model,
|
||||
mock_account,
|
||||
role: TenantAccountRole,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
@ -19,7 +21,7 @@ def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_accumulates_variables(mocker):
|
||||
def test_stream_node_events_accumulates_variables(mocker: MockerFixture):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
@ -44,7 +46,7 @@ class _GP:
|
||||
call_depth = 0
|
||||
|
||||
|
||||
def test_node_integration_minimal_stream(mocker):
|
||||
def test_node_integration_minimal_stream(mocker: MockerFixture):
|
||||
sys_d = {
|
||||
"sys": {
|
||||
"datasource_type": "online_document",
|
||||
@ -71,7 +73,7 @@ def test_node_integration_minimal_stream(mocker):
|
||||
|
||||
node = DatasourceNode(
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
data=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user