Merge remote-tracking branch 'origin/main'

# Conflicts:
#	api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
FFXN 2026-04-14 13:54:31 +08:00
commit 03660c19ef
466 changed files with 9399 additions and 6724 deletions


@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: readability and selective-run tag only; it does not change session behavior unless the hooks implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
Lead findings with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)


@ -0,0 +1,4 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."


@ -0,0 +1,93 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
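To make the parameter types above concrete, the sketch below mimics how an expression such as `I should see {int} apps named {string}` could turn into typed step arguments. It is a simplified illustration, not the real Cucumber Expressions parser: `compile`, `matchStep`, and the pattern table are hypothetical names, only double-quoted `{string}` values are handled, and optional text and alternation are ignored.

```typescript
// Simplified illustration only; the real parser ships with Cucumber itself.
type ParamKind = "string" | "int" | "float" | "word";

// Assumed, reduced patterns for each parameter type.
const PARAM_PATTERNS: Record<ParamKind, string> = {
  string: `"([^"]*)"`,
  int: "(-?\\d+)",
  float: "(-?\\d*\\.\\d+)",
  word: "(\\S+)",
};

// Escape literal step text so it is matched verbatim.
function escapeRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

// Turn an expression into a regex plus the ordered list of parameter kinds.
function compile(expression: string): { regex: RegExp; kinds: ParamKind[] } {
  const kinds: ParamKind[] = [];
  const token = /\{(string|int|float|word)\}/g;
  let source = "";
  let lastIndex = 0;
  let match: RegExpExecArray | null;
  while ((match = token.exec(expression)) !== null) {
    source += escapeRegExp(expression.slice(lastIndex, match.index));
    const kind = match[1] as ParamKind;
    kinds.push(kind);
    source += PARAM_PATTERNS[kind];
    lastIndex = match.index + match[0].length;
  }
  source += escapeRegExp(expression.slice(lastIndex));
  return { regex: new RegExp(`^${source}$`), kinds };
}

// Return converted arguments, or null when the step text does not match.
function matchStep(expression: string, text: string): Array<string | number> | null {
  const { regex, kinds } = compile(expression);
  const found = regex.exec(text);
  if (found === null) return null;
  return kinds.map((kind, i) => {
    const raw = found[i + 1] ?? "";
    return kind === "int" || kind === "float" ? Number(raw) : raw;
  });
}

matchStep("I should see {int} apps named {string}", 'I should see 3 apps named "Demo"');
// → [3, "Demo"]
```

The point of the sketch is that `{int}` arrives in the step body as a number and `{string}` as the unquoted text, so the step definition needs no parsing logic of its own.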
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
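The `this` and `async function` rules above follow from how a runner hands the world to a step. The sketch below is a minimal model of that binding, assuming hypothetical names throughout: `SketchWorld` stands in for `DifyWorld`, and `defineStep` / `runStep` stand in for Cucumber's registration and execution, which are more involved in practice.

```typescript
// Hypothetical stand-in for DifyWorld; the real class lives in e2e/features/support/world.ts.
class SketchWorld {
  visitedPaths: string[] = [];
}

// A step body receives the world as `this`, so it must be a regular function.
type StepBody = (this: SketchWorld, ...args: string[]) => Promise<void>;

const registry = new Map<string, StepBody>();

function defineStep(pattern: string, body: StepBody): void {
  registry.set(pattern, body);
}

// The runner binds one world per scenario when it invokes a step.
function runStep(world: SketchWorld, pattern: string, ...args: string[]): Promise<void> {
  const body = registry.get(pattern);
  if (body === undefined) {
    return Promise.reject(new Error(`Undefined step: ${pattern}`));
  }
  return body.call(world, ...args);
}

defineStep("I open {string}", async function (this: SketchWorld, path: string) {
  // `this` is the per-scenario world because the runner used `.call(world, ...)`.
  this.visitedPaths.push(path);
});
```

An arrow function in the same position would ignore the bound world and capture `this` from the enclosing scope instead, which is why the local style requires `async function`. Because each scenario gets its own world instance, state written through `this` does not leak into other scenarios.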
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
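The tag semantics listed above can be summarized as a small decision function. This is a sketch under stated assumptions, not the suite's implementation: `SessionMode` and `resolveSessionMode` are hypothetical names, `@apps` and `@smoke` are made-up capability tags, and the real decision lives in the hooks. `@fresh` is deliberately left out because it belongs to reset/full-install flows rather than ordinary scenarios.

```typescript
// Hypothetical names; the real decision lives in e2e/features/support/hooks.ts.
type SessionMode = "authenticated-shared-state" | "clean-context";

function resolveSessionMode(tags: readonly string[]): SessionMode {
  // Only `@unauthenticated` switches behavior; `@authenticated` is descriptive.
  const wantsCleanContext = tags.some((tag) => tag === "@unauthenticated");
  return wantsCleanContext ? "clean-context" : "authenticated-shared-state";
}

resolveSessionMode(["@apps", "@authenticated"]); // "authenticated-shared-state"
resolveSessionMode(["@smoke", "@unauthenticated"]); // "clean-context"
resolveSessionMode([]); // "authenticated-shared-state"
```

Note that removing `@authenticated` from the first call would not change the result, which is exactly what "descriptive, not a behavior switch" means.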
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?


@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
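The difference between a one-shot check and a retrying assertion can be modeled without a browser. The sketch below is a conceptual model only, with hypothetical names (`SlowElement`, `checkOnce`, `checkWithRetry`): real web-first assertions retry against a timeout rather than an attempt count, and step definitions should keep using `expect(...)` instead of hand-written loops like this one.

```typescript
// Hypothetical stand-in for an element that becomes visible after a few checks.
class SlowElement {
  private checks = 0;
  private readonly visibleAfterChecks: number;

  constructor(visibleAfterChecks: number) {
    this.visibleAfterChecks = visibleAfterChecks;
  }

  isVisible(): boolean {
    this.checks += 1;
    return this.checks > this.visibleAfterChecks;
  }
}

// One-shot check: inspects the state exactly once, like `expect(await locator.isVisible())`.
function checkOnce(element: SlowElement): boolean {
  return element.isVisible();
}

// Retrying check: re-evaluates the condition until it holds or attempts run out.
function checkWithRetry(element: SlowElement, maxAttempts: number): boolean {
  for (let attempt = 0; attempt < maxAttempts; attempt += 1) {
    if (element.isVisible()) return true;
  }
  return false;
}

checkOnce(new SlowElement(2)); // false: the element was not ready yet
checkWithRetry(new SlowElement(2), 5); // true: visible on the third check
```

The one-shot form fails whenever the UI is a moment late, which is the flake pattern the "Avoid" list above describes.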
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?


@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright


@ -6,14 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}


@ -92,6 +92,7 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'


@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh


@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests


@ -69,8 +69,6 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@ -84,7 +82,6 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@ -93,29 +90,16 @@ ignore = [
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."


@ -21,8 +21,9 @@ RUN apt-get update \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev
# production stage


@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")


@ -1,4 +1,3 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
HOLOGRES_TOKENIZER: str = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
HOLOGRES_DISTANCE_METHOD: str = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)


@ -1,5 +1,7 @@
"""Configuration for InterSystems IRIS vector database."""
from typing import Any
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate IRIS configuration values.
Args:


@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)


@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")


@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,


@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING if not specified
triggered_from = (
@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (


@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine).begin() as session:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(


@ -1,7 +1,10 @@
import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -42,12 +45,13 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@ -91,10 +95,12 @@ class LoginApi(Resource):
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
@ -110,14 +116,20 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)


@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),


@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal
from typing import Any, Literal
import pytz
from flask import request
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")


@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from services.operation_service import OperationService, UtmInfo
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)


@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict) -> VariableEntity:
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
return VariableEntity(
type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)


@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
summaries: dict[str, str | None] = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
result: list[dict[str, Any]] = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result


@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
payload: FormDefinitionPayload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),


@ -1,7 +1,10 @@
import logging
from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@field_validator("password")
@ -76,14 +81,18 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
payload = LoginPayload.model_validate(web_ns.payload or {})
normalized_email = payload.email.lower()
try:
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != payload.code:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(payload.token)
account = WebAppAuthService.get_user_through_email(token_email)
try:
account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
if not account:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Web login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)
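
A minimal sketch of the failure-logging helper pattern shown above, runnable outside Flask. The logger name, the `FailureReason` members, and the explicit `ip_address` argument are assumptions for illustration; the real helper takes the address from `extract_remote_ip(request)`.

```python
import logging
from enum import Enum

logger = logging.getLogger("web_login_sketch")


class FailureReason(str, Enum):
    """Hypothetical subset of login failure reasons."""

    INVALID_CREDENTIALS = "invalid_credentials"
    ACCOUNT_NOT_FOUND = "account_not_found"


def log_login_failure(*, email: str, reason: FailureReason, ip_address: str) -> None:
    # Lazy %-style arguments: the message is only formatted if a handler emits it.
    logger.warning(
        "Web login failed: email=%s reason=%s ip_address=%s",
        email.lower(),
        reason.value,
        ip_address,
    )
```

Keyword-only arguments keep call sites readable when the same helper is invoked from many branches.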


@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request
from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded = PassportService().verify(jwt_token)
decoded: dict[str, Any] = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
"""
Exchange a token for an existing web user session.
"""


@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
}
def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))


@ -138,7 +138,9 @@ class DatasetConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for dataset feature
@ -172,7 +174,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
"""
Extract dataset config for legacy compatibility


@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict):
def validate_model_completion_params(cls, cp: dict[str, Any]):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")


@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict):
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
"""
Validate post_prompt and set defaults for prompt feature


@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for user input form
@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for external data fetch feature


@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for file upload feature


@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationError
@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:


@ -1,6 +1,9 @@
from typing import Any
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for opening statement feature


@ -1,6 +1,9 @@
from typing import Any
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for retriever resource feature


@ -1,6 +1,9 @@
from typing import Any
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for speech to text feature


@ -1,6 +1,9 @@
from typing import Any
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
def convert(cls, config: dict[str, Any]) -> bool:
"""
Convert model config to model config
@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature


@ -1,9 +1,11 @@
from typing import Any
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict):
def convert(cls, config: dict[str, Any]):
"""
Convert model config to model config
@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for text to speech feature


@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response


@ -1,3 +1,5 @@
from typing import Any
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
"""
Validate for pipeline config


@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict | None = None,
next_page_parameters: dict[str, Any] | None = None,
):
"""
Get files in a folder.


@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict = Field(default_factory=dict)
extras: dict[str, Any] = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict | None = None
extras: dict[str, Any] | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus


@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
description: I18nObject
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict | None
team_credentials: dict[str, Any] | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: dict | None = None
original_credentials: dict | None = None
masked_credentials: dict[str, Any] | None = None
original_credentials: dict[str, Any] | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")


@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: dict | None = None
output_schema: dict[str, Any] | None = None
@field_validator("parameters", mode="before")
@classmethod
@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None
@classmethod
def empty(cls) -> DatasourceInvokeMeta:
@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: dict | None = Field(None, description="The page icon")
page_icon: dict[str, Any] | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: str | None = Field(None, description="The parent page id")
@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):


@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
data_source_info: dict[str, Any] | None = None
name: str
indexing_status: str
error: str | None = None


@ -6,6 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
"""
Get current credentials.
@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
"""
Get provider credentials.
@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
"""
Validate custom credentials.
:param credentials: provider credentials
@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential(
self,
credentials: dict,
credentials: dict[str, Any],
credential_id: str,
credential_name: str | None,
):
@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict | None:
) -> dict[str, Any] | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
"""
Get custom model credentials.
@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
) -> None:
"""
Update a custom model credential.
@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
"""
Get model schema
"""
@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.


@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Union
from typing import Any, Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class CustomProviderConfiguration(BaseModel):
@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""
credentials: dict
credentials: dict[str, Any]
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict | None
credentials: dict[str, Any] | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []
@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
id: str
name: str
credentials: dict
credentials: dict[str, Any]
credential_source_type: str | None = None
credential_id: str | None = None


@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.


@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@ -15,7 +16,7 @@ class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.
@ -33,7 +34,7 @@ class ProviderCredentialsCache:
else:
return None
def set(self, credentials: dict):
def set(self, credentials: dict[str, Any]):
"""
Cache model provider credentials.


@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""Get cached provider credentials"""
return None


@ -1,6 +1,7 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@ -18,7 +19,7 @@ class ToolParameterCache:
f":identity_id:{identity_id}"
)
def get(self) -> dict | None:
def get(self) -> dict[str, Any] | None:
"""
Get cached model provider credentials.
@ -36,7 +37,7 @@ class ToolParameterCache:
else:
return None
def set(self, parameters: dict):
def set(self, parameters: dict[str, Any]):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))


@ -735,7 +735,9 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: dict[Any, Any] | None = None,
):
"""
Update the document indexing status.
@ -762,7 +764,7 @@ class IndexingRunner:
db.session.commit()
@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
"""
Update the document segment by document id.
"""


@ -200,7 +200,7 @@ def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict,
model_parameters: dict[str, Any],
rules: list[ParameterRule],
):
"""
@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict, rules: list):
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
"""
Set the appropriate response format parameter based on model rules.
@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict):
def remove_additional_properties(schema: dict[str, Any]):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.


@ -77,7 +77,7 @@ class ModelInstance:
@staticmethod
def _get_load_balancing_manager(
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
) -> Optional["LBModelManager"]:
"""
Get load balancing model credentials
@ -115,7 +115,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
@ -126,7 +126,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
@ -137,7 +137,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -147,7 +147,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
@ -528,7 +528,7 @@ class LBModelManager:
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: dict | None = None,
managed_credentials: dict[str, Any] | None = None,
):
"""
Load balancing model manager


@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -23,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -41,7 +43,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -73,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))


@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = Field(default_factory=dict)
inputs: dict[str, Any] = Field(default_factory=dict)
query: str = ""
@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
"""
Validate the incoming form config data.
@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):


@ -1,3 +1,5 @@
from typing import Any
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict):
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -24,7 +26,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review


@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:


@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelManager
@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict):
def _is_violated(self, inputs: dict[str, Any]):
text = "\n".join(str(value) for value in inputs.values())
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(


@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
def set_tool_call_attributes(
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return


@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@staticmethod
def _get_completion_start_time(
start_time: datetime | None, time_to_first_token: float | int | None
) -> datetime | None:
"""Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
if start_time is None or time_to_first_token is None:
return None
try:
ttft_seconds = float(time_to_first_token)
except (TypeError, ValueError):
return None
if ttft_seconds < 0:
return None
return start_time + timedelta(seconds=ttft_seconds)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
completion_start_time = None
try:
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
usage_data = process_data.get("usage")
if not isinstance(usage_data, dict):
usage_data = outputs.get("usage")
if not isinstance(usage_data, dict):
usage_data = {}
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
completion_start_time = self._get_completion_start_time(
created_at, usage_data.get("time_to_first_token")
)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_id=trace_id,
model=process_data.get("model_name"),
start_time=created_at,
completion_start_time=completion_start_time,
end_time=finished_at,
input=inputs,
output=outputs,
@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
unit=UnitEnum.TOKENS,
totalCost=message_data.total_price,
)
completion_start_time = self._get_completion_start_time(
trace_info.start_time,
trace_info.gen_ai_server_time_to_first_token,
)
langfuse_generation_data = LangfuseGeneration(
name="llm",
trace_id=trace_id,
start_time=trace_info.start_time,
completion_start_time=completion_start_time,
end_time=trace_info.end_time,
model=message_data.model_id,
input=trace_info.inputs,
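
Note on the TTFT conversion in this file: the new `_get_completion_start_time` helper turns a relative time-to-first-token (seconds) into an absolute timestamp and returns `None` for missing, non-numeric, or negative values. The following is a minimal standalone sketch of that behaviour, not the Dify method itself; the function name `completion_start_time` and the sample timestamp are made up for illustration.

```python
from __future__ import annotations

from datetime import datetime, timedelta


def completion_start_time(start_time: datetime | None, ttft: float | int | None) -> datetime | None:
    """Return start_time + ttft seconds, or None when an input is missing or invalid."""
    if start_time is None or ttft is None:
        return None
    try:
        seconds = float(ttft)
    except (TypeError, ValueError):
        # Non-numeric values (for example a stray string) are ignored.
        return None
    if seconds < 0:
        # A negative offset would place the first token before the request started.
        return None
    return start_time + timedelta(seconds=seconds)


# Hypothetical sample values, chosen only to show the three outcomes.
start = datetime(2026, 4, 14, 13, 54, 31)
print(completion_start_time(start, 0.25))   # 2026-04-14 13:54:31.250000
print(completion_start_time(start, -1))     # None
print(completion_start_time(start, "n/a"))  # None
```

Returning `None` instead of raising appears to be the intent in the diff as well, so that a malformed usage payload only drops the optional `completion_start_time` field rather than the whole generation record.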


@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
return inputs, attributes
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict) -> str | None:
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.get(EndUser, end_user_id)
):
@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
def _set_trace_metadata(self, span: Span, metadata: dict):
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict):
def _parse_single_message(self, item: dict[str, Any]):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}


@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import cast
from typing import Any, cast
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from opik import Opik, Trace
@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict) -> Trace:
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict):
def add_span(self, opik_span_data: dict[str, Any]):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")


@ -324,7 +324,7 @@ class OpsTraceManager:
@classmethod
def encrypt_tracing_config(
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
):
"""
Encrypt tracing config.
@ -363,7 +363,7 @@ class OpsTraceManager:
return encrypted_config.model_dump()
@classmethod
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tenant_id: tenant id
@ -408,7 +408,7 @@ class OpsTraceManager:
return dict(decrypted_config)
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
@ -581,7 +581,7 @@ class OpsTraceManager:
return app_trace_config
@staticmethod
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
"""
Check whether the trace config is effective
:param tracing_config: tracing config
@ -596,7 +596,7 @@ class OpsTraceManager:
return trace_instance(config).api_check()
@staticmethod
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
"""
Get the project key from the trace config
:param tracing_config: tracing config
@ -611,7 +611,7 @@ class OpsTraceManager:
return trace_instance(config).get_project_key()
@staticmethod
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
"""
Get the project URL from the trace config
:param tracing_config: tracing config
@ -1322,8 +1322,8 @@ class TraceTask:
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
if not node_data:
return {}
@ -1431,7 +1431,7 @@ class TraceTask:
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
if not message_data.message_metadata:
return {}


@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, model_validator
@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
entity of an endpoint
"""
settings: dict
settings: dict[str, Any]
tenant_id: str
plugin_id: str
expired_at: datetime


@ -1,3 +1,5 @@
from typing import Any
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from pydantic import BaseModel, Field, computed_field, model_validator
@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def transform_declaration(cls, data: dict):
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
if "endpoint" in data and not data["endpoint"]:
del data["endpoint"]
if "model" in data and not data["model"]:


@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict):
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool


@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
"""
result: bool
credentials: dict | None = None
credentials: dict[str, Any] | None = None
class PluginModelSchemaEntity(BaseModel):


@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict
tool_parameters: dict[str, Any]
credential_id: str | None = None
@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
opt: Literal["encrypt", "decrypt", "clear"]
namespace: Literal["endpoint"]
identity: str
data: dict = Field(default_factory=dict)
data: dict[str, Any] = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)


@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):


@ -1,3 +1,5 @@
from typing import Any
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
self,
tenant_id: str,
user_id: str,
plugin_unique_identifier: str,
name: str,
settings: dict[str, Any],
) -> bool:
"""
Create an endpoint for the given plugin.
@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
)
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
def update_endpoint(
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
) -> bool:
"""
Update the settings of the given endpoint.
"""


@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> AIModelEntity | None:
"""
Get model schema
@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
) -> bool:
"""
validate the credentials of the provider
@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
) -> bool:
"""
validate the credentials of the provider
@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
model_parameters: dict | None = None,
model_parameters: dict[str, Any] | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
"""
@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None = None,
@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
language: str | None = None,
):
"""
@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
"""
@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict,
credentials: dict[str, Any],
text: str,
) -> bool:
"""


@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from requests import HTTPError
@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
original_plugin_unique_identifier: str,
new_plugin_unique_identifier: str,
source: PluginInstallationSource,
meta: dict,
meta: dict[str, Any],
) -> PluginInstallTaskStartResponse:
"""
Upgrade a plugin.


@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str | None = None,
context: str | None = None,
histories: str | None = None,
) -> tuple[str, dict]:
) -> tuple[str, dict[str, Any]]:
# get prompt template
prompt_template_config = self.get_prompt_template(
app_mode=app_mode,
@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],
@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
self,
app_mode: AppMode,
pre_prompt: str,
inputs: dict,
inputs: dict[str, Any],
query: str,
context: str | None,
files: Sequence["File"],


@ -856,7 +856,7 @@ class ProviderManager:
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
) -> dict[str, Any]:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,


@ -174,8 +174,8 @@ class RetrievalService:
cls,
dataset_id: str,
query: str,
external_retrieval_model: dict | None = None,
metadata_filtering_conditions: dict | None = None,
external_retrieval_model: dict[str, Any] | None = None,
metadata_filtering_conditions: dict[str, Any] | None = None,
):
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)


@ -0,0 +1,87 @@
"""Vector store backend discovery.
Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
declares third-party dependencies and registers ``importlib`` entry points in group
``dify.vector_backends`` (see each package's ``pyproject.toml``).
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
distribution; entry points take precedence when both exist.
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
"""
from __future__ import annotations
import importlib
import logging
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
logger = logging.getLogger(__name__)
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
# module_path:class_name — optional fallback when no distribution registers the backend.
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
def clear_vector_factory_cache() -> None:
"""Drop lazily loaded factories (for tests or plugin reload)."""
_VECTOR_FACTORY_CACHE.clear()
def _vector_backend_entry_points():
return entry_points().select(group="dify.vector_backends")
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
for ep in _vector_backend_entry_points():
if ep.name != vector_type:
continue
try:
loaded = ep.load()
except Exception:
logger.exception("Failed to load vector backend entry point %s", ep.name)
raise
return loaded # type: ignore[return-value]
return None
def _unsupported(vector_type: str) -> ValueError:
installed = sorted(ep.name for ep in _vector_backend_entry_points())
available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
return ValueError(
f"Vector store {vector_type!r} is not supported.{available_msg} "
"Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
"or register a dify.vector_backends entry point."
)
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
if not target:
raise _unsupported(vector_type)
module_path, _, attr = target.partition(":")
module = importlib.import_module(module_path)
return getattr(module, attr) # type: ignore[no-any-return]
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
"""Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
if vector_type in _VECTOR_FACTORY_CACHE:
return _VECTOR_FACTORY_CACHE[vector_type]
plugin_cls = _load_plugin_factory(vector_type)
if plugin_cls is not None:
_VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
return plugin_cls
cls = _load_builtin_factory(vector_type)
_VECTOR_FACTORY_CACHE[vector_type] = cls
return cls


@ -9,6 +9,7 @@ from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
@ -85,137 +86,7 @@ class Vector:
@staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case VectorType.TABLESTORE:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
case VectorType.HUAWEI_CLOUD:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
return get_vector_factory_class(vector_type)
def create(self, texts: list | None = None, **kwargs):
if texts:

View File

@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).
:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.
The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""
import uuid
from unittest.mock import MagicMock
import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset
@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
return doc
@pytest.fixture
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
# set
ext_redis.redis_client.set = MagicMock(return_value=None)
# lock
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractVectorTest:
vector: BaseVector
def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
self.example_doc_id = str(uuid.uuid4())

View File

@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):
return embedding_results # type: ignore
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
class Embeddings(ABC):
@ -20,7 +21,7 @@ class Embeddings(ABC):
raise NotImplementedError
@abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal query."""
raise NotImplementedError
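
Note on the recurring `dict` to `dict[str, Any]` change: it does not alter runtime behavior. A type checker treats a bare `dict` as `dict[Any, Any]`, while the parameterized form records that keys are strings and values are unconstrained. The sketch below inspects such an annotation at runtime. `file_id_of` is a hypothetical helper, not part of this diff.

```python
from typing import Any, get_args, get_type_hints


def file_id_of(multimodal_document: dict[str, Any]) -> str:
    """Hypothetical helper that reads the "file_id" key, as the diff's method does."""
    return str(multimodal_document["file_id"])


# The annotation is introspectable: keys are str, values are Any.
hints = get_type_hints(file_id_of)
key_type, value_type = get_args(hints["multimodal_document"])
```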

View File

@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""
import csv
from typing import Any
import pandas as pd
@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor):
encoding: str | None = None,
autodetect_encoding: bool = False,
source_column: str | None = None,
csv_args: dict | None = None,
csv_args: dict[str, Any] | None = None,
):
"""Initialize with file path."""
self._file_path = file_path

View File

@ -54,8 +54,8 @@ class BaseAPIClient:
self,
method: str,
endpoint: str,
query_params: dict | None = None,
data: dict | None = None,
query_params: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
**kwargs,
) -> Response:
stream = kwargs.pop("stream", False)
@ -66,19 +66,25 @@ class BaseAPIClient:
return self.session.request(method, url, params=query_params, json=data, **kwargs)
def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("GET", endpoint, query_params=query_params, **kwargs)
def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _post(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)
def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _put(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)
def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
return self._request("DELETE", endpoint, query_params=query_params, **kwargs)
def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
def _patch(
self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
):
return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)
@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
finally:
response.close()
def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator:
if response.status_code == 401:
raise WaterCrawlAuthenticationError(response)
@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
yield from generator
def get_crawl_request_results(
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None
):
query_params = query_params or {}
query_params.update({"page": page or 1, "page_size": page_size or 25})
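
Note on the two lines above: `query_params or {}` reuses the caller's dict whenever it is non-empty, so the following `update` writes `page` and `page_size` into the caller's object. Whether any caller depends on that is not visible in this diff. A minimal sketch of a copying variant, using a hypothetical helper name:

```python
from typing import Any, Optional


def build_page_params(
    page: int = 1,
    page_size: int = 25,
    query_params: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    # Hypothetical helper: copy first so the caller's dict is left untouched.
    merged = dict(query_params or {})
    # Same fallbacks as the diff: falsy values (0, None) become the defaults.
    merged.update({"page": page or 1, "page_size": page_size or 25})
    return merged


caller_params = {"prefetched": True}
merged = build_page_params(page=0, page_size=0, query_params=caller_params)
```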
@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
if event_data["type"] == "result":
return event_data["data"]
def download_result(self, result_object: dict):
def download_result(self, result_object: dict[str, Any]):
response = httpx.get(result_object["result"], timeout=None)
try:
response.raise_for_status()

View File

@ -120,7 +120,7 @@ class WaterCrawlProvider:
}
def _get_results(
self, crawl_request_id: str, query_params: dict | None = None
self, crawl_request_id: str, query_params: dict[str, Any] | None = None
) -> Generator[WatercrawlDocumentData, None, None]:
page = 0
page_size = 100

View File

@ -875,7 +875,11 @@ class DatasetRetrieval:
return retrieval_resource_list
def _on_retrieval_end(
self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
self,
flask_app: Flask,
documents: list[Document],
message_id: str | None = None,
timer: dict[str, Any] | None = None,
):
"""Handle retrieval end."""
with flask_app.app_context():
@ -980,7 +984,7 @@ class DatasetRetrieval:
self._send_trace_task(message_id, documents, timer)
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
"""Send trace task if trace manager is available."""
trace_manager: TraceQueueManager | None = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
@ -1142,7 +1146,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list[DatasetRetrieverBaseTool] | None:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@ -1337,7 +1341,7 @@ class DatasetRetrieval:
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: MetadataFilteringCondition | None,
inputs: dict,
inputs: dict[str, Any],
) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
document_query = select(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
@ -1417,7 +1421,7 @@ class DatasetRetrieval:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition
def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
if not inputs:
return text

View File

@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
return _case_routing
def __getattr__(name: str) -> dict:
def __getattr__(name: str) -> Any:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()

View File

@ -198,7 +198,7 @@ class Tool(ABC):
message=ToolInvokeMessage.TextMessage(text=text),
)
def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
"""
create a blob message
@ -212,7 +212,7 @@ class Tool(ABC):
meta=meta,
)
def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
"""
create a json message
"""

View File

@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict | list
json_object: dict[str, Any] | list[Any]
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: str | None = None
# MCP object and array type parameters use this field to store the schema
input_schema: dict | None = None
input_schema: dict[str, Any] | None = None
@classmethod
def get_simple_instance(
@ -463,7 +463,7 @@ class ToolInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
tool_config: dict[str, Any] | None = None
@classmethod
def empty(cls) -> ToolInvokeMeta:

View File

@ -85,7 +85,8 @@ class ToolEngine:
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
invocation_meta_dict: dict[str, Any],
messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
@ -200,7 +201,7 @@ class ToolEngine:
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
tool_parameters: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,

View File

@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
inputs: dict[str, Any],
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Any
from uuid import UUID
import numpy as np
@ -50,7 +51,7 @@ def safe_json_value(v):
return v
def safe_json_dict(d: dict):
def safe_json_dict(d: dict[str, Any]):
if not isinstance(d, dict):
raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
return {k: safe_json_value(v) for k, v in d.items()}
@ -196,11 +197,11 @@ class ToolFileMessageTransformer:
@staticmethod
def _with_tool_file_meta(
meta: dict | None,
meta: dict[str, Any] | None,
*,
tool_file_id: str | None = None,
url: str | None = None,
) -> dict:
) -> dict[str, Any]:
normalized_meta = meta.copy() if meta is not None else {}
resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
if resolved_tool_file_id and "tool_file_id" not in normalized_meta:

View File

@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser:
return value
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: str | None = None
if parameter.get("format") == "binary":
@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
openapi: dict[str, Any] = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> OpenAPISpecDict:
warning = warning or {}
"""
@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
"""
auto parse to tool bundle

View File

@ -277,7 +277,7 @@ class WorkflowTool(Tool):
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""
transform the tool parameters
@ -323,7 +323,7 @@ class WorkflowTool(Tool):
return parameters_result, files
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
"""
extract files from the result
@ -355,7 +355,7 @@ class WorkflowTool(Tool):
return result, files
def _update_file_mapping(self, file_dict: dict):
def _update_file_mapping(self, file_dict: dict[str, Any]):
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
match transfer_method:

View File

@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol):
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
summary_index_setting: dict[str, Any] | None = None,
) -> IndexingResultDict: ...
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
self,
chunks: Any,
dataset_id: str,
document_id: str,
chunk_structure: str,
summary_index_setting: dict[str, Any] | None,
) -> Preview: ...
class SummaryIndexServiceProtocol(Protocol):
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
) -> None: ...

View File

@ -1,4 +1,4 @@
from typing import Literal, Union
from typing import Any, Literal, Union
from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType
@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData):
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
visual_config: dict | None = Field(default=None, description="Visual configuration details")
visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details")
timezone: str = Field(default="UTC", description="Timezone for schedule execution")

View File

@ -75,7 +75,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs=outputs,
)
def generate_file_var(self, param_name: str, file: dict):
def generate_file_var(self, param_name: str, file: dict[str, Any]):
file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@ -147,7 +147,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self.node_data.content_type == ContentType.BINARY:
raw_data: dict = webhook_data.get("body", {}).get("raw", {})
raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {})
file_var = self.generate_file_var(param_name, raw_data)
if file_var:
outputs[param_name] = file_var

View File

@ -10,6 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any
import clickzetta
from pydantic import BaseModel, model_validator
@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
def validate_config(cls, values: dict[str, Any]):
"""Validate the configuration values.
This method will first try to use CLICKZETTA_VOLUME_* environment variables,

View File

@ -65,7 +65,7 @@ class FileMetadata:
return data
@classmethod
def from_dict(cls, data: dict) -> FileMetadata:
def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
"""Create instance from dictionary"""
data = data.copy()
data["created_at"] = datetime.fromisoformat(data["created_at"])
@ -459,7 +459,7 @@ class FileLifecycleManager:
newest_file=None,
)
def _create_version_backup(self, filename: str, metadata: dict):
def _create_version_backup(self, filename: str, metadata: dict[str, Any]):
"""Create version backup"""
try:
# Read current file content
@ -487,7 +487,7 @@ class FileLifecycleManager:
logger.warning("Failed to load metadata: %s", e)
return {}
def _save_metadata(self, metadata_dict: dict):
def _save_metadata(self, metadata_dict: dict[str, Any]):
"""Save metadata file"""
try:
metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)

Some files were not shown because too many files have changed in this diff.