Merge remote-tracking branch 'origin/main'

# Conflicts:
#	api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py
FFXN 2026-04-14 13:54:31 +08:00
commit 03660c19ef
466 changed files with 9399 additions and 6724 deletions


@@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: a readability/selective-run tag only; it does not change session behavior unless the hook implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
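A minimal sketch of a step definition in this local style — the feature wording, the `renameApp` flow, and the relative import path are hypothetical; only the `this: DifyWorld` typing, the `async function` form, and the one-action/one-assertion split are the rules above:

```typescript
import { When, Then } from '@cucumber/cucumber';
import { expect } from '@playwright/test';
import type { DifyWorld } from '../support/world';

// One user-visible action per step; `this` is the per-scenario DifyWorld,
// and a regular async function (not an arrow function) preserves `this`.
When('the user renames the app to {string}', async function (this: DifyWorld, name: string) {
  await this.page.getByRole('textbox', { name: 'App name' }).fill(name);
  await this.page.getByRole('button', { name: 'Save' }).click();
});

// One web-first assertion per step; expect(...) retries until it passes.
Then('the app title shows {string}', async function (this: DifyWorld, name: string) {
  await expect(this.page.getByRole('heading', { name })).toBeVisible();
});
```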
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
Lead findings with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)


@@ -0,0 +1,4 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."


@@ -0,0 +1,93 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
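As a rough illustration of what these parameter types capture, Cucumber compiles an expression like `the list shows {int} apps named {string}` into a regex roughly equivalent to the pattern below (the step wording is made up):

```typescript
// Approximation only: `{int}` matches a whole number, `{string}` a quoted value.
const pattern = /^the list shows (\d+) apps named "([^"]*)"$/;

const match = 'the list shows 3 apps named "Demo"'.match(pattern);
const count = match ? Number(match[1]) : null; // the {int} capture
const name = match ? match[2] : null;          // the {string} capture, quotes stripped
```

The real library also handles escaping and alternative types, but this is the shape of the contract: readable wording outside, typed captures inside.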
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?


@@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
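A sketch of scoping for repeated content, assuming a hypothetical page that renders many app cards (`page` is a Playwright `Page`; the roles and names are illustrative, not Dify's actual markup):

```typescript
import { expect, type Page } from '@playwright/test';

// Scope to one stable container first, then use user-facing locators inside it.
async function openAppCard(page: Page, appName: string) {
  const card = page.getByRole('listitem').filter({ hasText: appName });
  await expect(card).toBeVisible();
  // Role + accessible name survives DOM refactors better than a CSS path
  // like `.card:nth-child(3) button` would.
  await card.getByRole('button', { name: 'Open' }).click();
}
```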
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
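The good and bad patterns above can be sketched side by side (the dialog and button names are hypothetical):

```typescript
import { expect, type Page } from '@playwright/test';

// Good: assert the meaningful visible state, then act. click() itself waits
// for the button to be attached, visible, stable, and enabled.
async function createApp(page: Page) {
  await expect(page.getByRole('dialog', { name: 'Create app' })).toBeVisible();
  await page.getByRole('button', { name: 'Create' }).click();
}

// Bad: arbitrary sleeps and manual visibility polling duplicate what the
// locator action already does, mask real readiness problems, and flake under load.
async function createAppBad(page: Page) {
  await page.waitForTimeout(2000); // arbitrary wait
  if (await page.getByRole('button', { name: 'Create' }).isVisible()) {
    await page.getByRole('button', { name: 'Create' }).click();
  }
}
```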
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?


@@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright


@@ -6,14 +6,7 @@ on:
       - "main"
     paths:
       - api/Dockerfile
-      - web/docker/**
       - web/Dockerfile
-      - packages/**
-      - package.json
-      - pnpm-lock.yaml
-      - pnpm-workspace.yaml
-      - .npmrc
-      - .nvmrc
 concurrency:
   group: docker-build-${{ github.head_ref || github.run_id }}


@@ -92,6 +92,7 @@ jobs:
           vdb:
             - 'api/core/rag/datasource/**'
             - 'api/tests/integration_tests/vdb/**'
+            - 'api/providers/vdb/*/tests/**'
             - '.github/workflows/vdb-tests.yml'
             - '.github/workflows/expose_service_ports.sh'
             - 'docker/.env.example'


@@ -89,7 +89,7 @@ jobs:
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
       - name: Test Vector Stores
         run: uv run --project api bash dev/pytest/pytest_vdb.sh


@@ -81,12 +81,12 @@ jobs:
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
       - name: Test Vector Stores
         run: |
           uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
-            api/tests/integration_tests/vdb/chroma \
-            api/tests/integration_tests/vdb/pgvector \
-            api/tests/integration_tests/vdb/qdrant \
-            api/tests/integration_tests/vdb/weaviate
+            api/providers/vdb/vdb-chroma/tests/integration_tests \
+            api/providers/vdb/vdb-pgvector/tests/integration_tests \
+            api/providers/vdb/vdb-qdrant/tests/integration_tests \
+            api/providers/vdb/vdb-weaviate/tests/integration_tests


@@ -69,8 +69,6 @@ ignore = [
     "FURB152", # math-constant
     "UP007", # non-pep604-annotation
     "UP032", # f-string
-    "UP045", # non-pep604-annotation-optional
-    "B005", # strip-with-multi-characters
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
@@ -84,7 +82,6 @@ ignore = [
     "SIM102", # collapsible-if
     "SIM103", # needless-bool
     "SIM105", # suppressible-exception
-    "SIM107", # return-in-try-except-finally
     "SIM108", # if-else-block-instead-of-if-exp
     "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
@@ -93,29 +90,16 @@ ignore = [
 ]
 [lint.per-file-ignores]
-"__init__.py" = [
-    "F401", # unused-import
-    "F811", # redefined-while-unused
-]
 "configs/*" = [
     "N802", # invalid-function-name
 ]
-"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
-"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
 "libs/gmpy2_pkcs10aep_cipher.py" = [
     "N803", # invalid-argument-name
 ]
 "tests/*" = [
-    "F811", # redefined-while-unused
     "T201", # allow print in tests,
     "S110", # allow ignoring exceptions in tests code (currently)
 ]
-"controllers/console/explore/trial.py" = ["TID251"]
-"controllers/console/human_input_form.py" = ["TID251"]
-"controllers/web/human_input_form.py" = ["TID251"]
-[lint.flake8-tidy-imports]
 [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
 msg = "Use Pydantic payload/query models instead of reqparse."


@@ -21,8 +21,9 @@ RUN apt-get update \
     # for building gmpy2
     libmpfr-dev libmpc-dev
-# Install Python dependencies
+# Install Python dependencies (workspace members under providers/vdb/)
 COPY pyproject.toml uv.lock ./
+COPY providers ./providers
 RUN uv sync --locked --no-dev
 # production stage


@@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
         click.echo(click.style("No dataset collection bindings found.", fg="red"))
         return
     import qdrant_client
+    from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
     from qdrant_client.http.exceptions import UnexpectedResponse
     from qdrant_client.http.models import PayloadSchemaType
-    from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
     for binding in bindings:
         if dify_config.QDRANT_URL is None:
             raise ValueError("Qdrant URL is required.")


@@ -1,4 +1,3 @@
-from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
 from pydantic import Field
 from pydantic_settings import BaseSettings
@@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
         default="public",
     )
-    HOLOGRES_TOKENIZER: TokenizerType = Field(
+    HOLOGRES_TOKENIZER: str = Field(
         description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
         default="jieba",
     )
-    HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
+    HOLOGRES_DISTANCE_METHOD: str = Field(
         description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
         default="Cosine",
     )
-    HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
+    HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
         description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
         default="rabitq",
     )


@@ -1,5 +1,7 @@
 """Configuration for InterSystems IRIS vector database."""
+from typing import Any
 from pydantic import Field, PositiveInt, model_validator
 from pydantic_settings import BaseSettings
@@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict) -> dict:
+    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
         """Validate IRIS configuration values.
         Args:


@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from libs.login import login_required
-from services.advanced_prompt_template_service import AdvancedPromptTemplateService
+from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
 class AdvancedPromptTemplateQuery(BaseModel):
@@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
     @account_initialization_required
     def get(self):
         args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
-        return AdvancedPromptTemplateService.get_prompt(args.model_dump())
+        prompt_args: AdvancedPromptTemplateArgs = {
+            "app_mode": args.app_mode,
+            "model_mode": args.model_mode,
+            "model_name": args.model_name,
+            "has_context": args.has_context,
+        }
+        return AdvancedPromptTemplateService.get_prompt(prompt_args)


@@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
 class MCPServerCreatePayload(BaseModel):
     description: str | None = Field(default=None, description="Server description")
-    parameters: dict = Field(..., description="Server parameters configuration")
+    parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
 class MCPServerUpdatePayload(BaseModel):
     id: str = Field(..., description="Server ID")
     description: str | None = Field(default=None, description="Server description")
-    parameters: dict = Field(..., description="Server parameters configuration")
+    parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
     status: str | None = Field(default=None, description="Server status")


@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
                 session=session,
                 app_model=app_model,
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
         args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
         workflow_app_service = WorkflowAppService()
-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
                 session=session,
                 app_model=app_model,


@@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
 from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
-from services.workflow_run_service import WorkflowRunService
+from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
 def _build_backstage_input_url(form_token: str | None) -> str | None:
@@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
         Get advanced chat app workflow run list
         """
         args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
-        args = args_model.model_dump(exclude_none=True)
+        args: WorkflowRunListArgs = {"limit": args_model.limit}
+        if args_model.last_id is not None:
+            args["last_id"] = args_model.last_id
+        if args_model.status is not None:
+            args["status"] = args_model.status
         # Default to DEBUGGING if not specified
         triggered_from = (
@@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
         Get workflow run list
         """
         args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
-        args = args_model.model_dump(exclude_none=True)
+        args: WorkflowRunListArgs = {"limit": args_model.limit}
+        if args_model.last_id is not None:
+            args["last_id"] = args_model.last_id
+        if args_model.status is not None:
+            args["status"] = args_model.status
         # Default to DEBUGGING for workflow if not specified (backward compatibility)
         triggered_from = (


@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
         node_id = args.node_id
-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             # Get webhook trigger for this app and node
             webhook_trigger = session.scalar(
                 select(WorkflowWebhookTrigger)
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             # Get all triggers for this app using select API
             triggers = (
                 session.execute(


@ -1,7 +1,10 @@
import logging
import flask_login import flask_login
from flask import make_response, request from flask import make_response, request
from flask_restx import Resource from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services import services
from configs import dify_config from configs import dify_config
@ -42,12 +45,13 @@ from libs.token import (
) )
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase): class LoginPayload(LoginPayloadBase):
@ -91,10 +95,12 @@ class LoginApi(Resource):
normalized_email = request_email.lower() normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError() raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email) is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit: if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError() raise EmailPasswordLoginLimitError()
invite_token = args.invite_token invite_token = args.invite_token
@ -110,14 +116,20 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email: if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError() raise InvalidEmailError()
account = _authenticate_account_with_case_fallback( account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token request_email, normalized_email, args.password, invite_token
) )
except services.errors.account.AccountLoginError: except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc: except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email) AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
    raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
    raise InvalidEmailError()
if token_data["code"] != args.code:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
    raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
    account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
    raise AccountBannedError() from exc
except AccountRegisterError:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
    raise AccountInFreezeError()
if account:
    tenants = TenantService.get_join_tenants(account)
@@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
    raise NotAllowedCreateWorkspace()
except AccountRegisterError:
    _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
    raise AccountInFreezeError()
except WorkspacesLimitExceededError:
    raise WorkspacesLimitExceeded()
@@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
    if original_email == normalized_email:
        raise
    return AccountService.authenticate(normalized_email, password, invite_token)


def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
    logger.warning(
        "Console login failed: email=%s reason=%s ip_address=%s",
        email,
        reason,
        extract_remote_ip(request),
    )
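The new `_log_console_login_failure` helper centralizes the audit log line behind keyword-only arguments. A standalone sketch of the same pattern; the enum members and template here are illustrative stand-ins for the real `LoginFailureReason` in `services.entities.auth_entities`:

```python
import logging
from enum import Enum

logger = logging.getLogger("login_audit")


class LoginFailureReason(Enum):
    # Hypothetical subset of the reasons referenced in the diff.
    INVALID_CREDENTIALS = "invalid_credentials"
    ACCOUNT_BANNED = "account_banned"


def format_login_failure(*, email: str, reason: LoginFailureReason, ip_address: str) -> str:
    # Same %-style template the helper hands to logger.warning(); passing the
    # arguments separately lets logging skip formatting when the level is filtered.
    template = "Console login failed: email=%s reason=%s ip_address=%s"
    logger.warning(template, email, reason.value, ip_address)
    return template % (email, reason.value, ip_address)


print(format_login_failure(email="user@example.com",
                           reason=LoginFailureReason.INVALID_CREDENTIALS,
                           ip_address="203.0.113.7"))
```

Keyword-only parameters (`*,`) keep call sites explicit, which matters when two string arguments like email and IP could otherwise be swapped silently.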

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict

from flask import request
@@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
    lang: str
    title: str
    subtitle: str
    body: str
    titlePicUrl: str
class NotificationItemDict(TypedDict):
    notification_id: str | None
    frequency: str | None
@@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
    notifications: list[NotificationItemDict]
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
    """Return the single LangContent for *lang*, falling back to English."""
    return (
        contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
    )
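The reworked `_pick_lang_content` relies on `or`-chaining, which also skips falsy (empty) entries, not just missing keys. A minimal standalone sketch of the same fallback chain, using plain dicts in place of the `NotificationLangContent` TypedDict:

```python
FALLBACK_LANG = "en-US"


def pick_lang_content(contents: dict[str, dict], lang: str) -> dict:
    # Fallback chain: exact language, then en-US, then any remaining entry,
    # then an empty payload when no languages exist at all.
    return contents.get(lang) or contents.get(FALLBACK_LANG) or next(iter(contents.values()), {})


contents = {
    "en-US": {"title": "Hello"},
    "ja-JP": {"title": "Kon'nichiwa"},
}
print(pick_lang_content(contents, "ja-JP"))
print(pick_lang_content(contents, "fr-FR"))
```

Note that because `or` tests truthiness, an entry that is present but empty (`{}`) falls through to the next candidate; that is the intended behavior here.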
class DismissNotificationPayload(BaseModel):
@@ -71,7 +82,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
    contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
    lang_content = _pick_lang_content(contents, lang)
    item: NotificationItemDict = {
        "notification_id": notification.get("notificationId"),

View File

@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Literal

import pytz
from flask import request
@@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)


def _serialize_account(account) -> dict[str, Any]:
    return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

View File

@@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService, UtmInfo

from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
    utm_info = request.cookies.get("utm_info")
    if utm_info:
        utm_info_dict: UtmInfo = json.loads(utm_info)
        OperationService.record_utm(current_tenant_id, utm_info_dict)
    return view(*args, **kwargs)

View File

@@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@@ -158,14 +158,20 @@ class MCPAppApi(Resource):
        except ValidationError as e:
            raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")

    def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
        """Convert raw user input form to VariableEntity objects"""
        return [self._create_variable_entity(item) for item in raw_form]

    def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
        """Create a single VariableEntity from raw form item"""
        variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
        try:
            variable_type = VariableEntityType(variable_type_raw)
        except ValueError as e:
            raise MCPRequestError(
                mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
            ) from e
        variable = item[variable_type_raw]
        return VariableEntity(
            type=variable_type,
@@ -178,7 +184,7 @@ class MCPAppApi(Resource):
            json_schema=variable.get("json_schema"),
        )

    def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
        """Parse and validate MCP request"""
        try:
            return mcp_types.ClientRequest.model_validate(args)
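The new `VariableEntityType(variable_type_raw)` call turns an unknown type string into a `ValueError`, which the handler rewraps as an MCP `INVALID_PARAMS` error instead of letting a raw exception escape. A standalone sketch with illustrative enum members (the real enum lives in `graphon.variables.input_entities`):

```python
from enum import Enum


class VariableEntityType(str, Enum):
    # Illustrative members only.
    TEXT_INPUT = "text-input"
    SELECT = "select"


def coerce_variable_type(raw: str) -> VariableEntityType:
    # Enum(value) does a lookup by value and raises ValueError for unknown
    # strings, which we rewrap with a caller-friendly message.
    try:
        return VariableEntityType(raw)
    except ValueError as e:
        raise ValueError(f"Invalid user_input_form variable type: {raw}") from e


print(coerce_variable_type("select"))
```

Chaining with `from e` preserves the original lookup failure in the traceback while surfacing a message the client can act on.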

View File

@@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService


def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
    """Marshal a single segment and enrich it with summary content."""
    segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
    summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
    segment_dict["summary"] = summary.summary_content if summary else None
    return segment_dict


def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
    """Marshal multiple segments and enrich them with summary content (batch query)."""
    segment_ids = [segment.id for segment in segments]
    summaries: dict[str, str | None] = {}
    if segment_ids:
        summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
        summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
    result: list[dict[str, Any]] = []
    for segment in segments:
        segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
        segment_dict["summary"] = summaries.get(segment.id)
        result.append(segment_dict)
    return result

View File

@@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
@@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
    return int(value.timestamp())


class FormDefinitionPayload(TypedDict):
    form_content: Any
    inputs: Any
    resolved_default_values: dict[str, str]
    user_actions: Any
    expiration_time: int
    site: NotRequired[dict]


def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
    """Return the form payload (optionally with site) as a JSON response."""
    definition_payload = form.get_definition().model_dump()
    payload: FormDefinitionPayload = {
        "form_content": definition_payload["rendered_content"],
        "inputs": definition_payload["inputs"],
        "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),

View File

@@ -1,7 +1,10 @@
import logging

from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
@@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService

logger = logging.getLogger(__name__)


class LoginPayload(LoginPayloadBase):
    @field_validator("password")
@@ -76,14 +81,18 @@ class LoginApi(Resource):
    def post(self):
        """Authenticate user and login."""
        payload = LoginPayload.model_validate(web_ns.payload or {})
        normalized_email = payload.email.lower()
        try:
            account = WebAppAuthService.authenticate(payload.email, payload.password)
        except services.errors.account.AccountLoginError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
            raise AccountBannedError()
        except services.errors.account.AccountPasswordError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
            raise AuthenticationFailedError()
        except services.errors.account.AccountNotFoundError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
            raise AuthenticationFailedError()

        token = WebAppAuthService.login(account=account)
@@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
    raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
    raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
    raise InvalidEmailError()
if token_data["code"] != payload.code:
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
    raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(payload.token)
try:
    account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
    raise AccountBannedError() from exc
if not account:
    _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
    raise AuthenticationFailedError()

token = WebAppAuthService.login(account=account)
@@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response


def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
    logger.warning(
        "Web login failed: email=%s reason=%s ip_address=%s",
        email,
        reason,
        extract_remote_ip(request),
    )

View File

@@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any

from flask import make_response, request
from flask_restx import Resource
@@ -103,21 +104,23 @@ class PassportResource(Resource):
        return response


def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
    """
    Decode the enterprise user session from the Authorization header.
    """
    if not jwt_token:
        return None
    decoded: dict[str, Any] = PassportService().verify(jwt_token)
    source = decoded.get("token_source")
    if not source or source != "webapp_login_token":
        raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
    return decoded


def exchange_token_for_existing_web_user(
    app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
    """
    Exchange a token for an existing web user session.
    """

View File

@@ -1,4 +1,4 @@
from typing import Any, cast

from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@@ -113,12 +113,12 @@ class AppSiteInfo:
        }


def serialize_site(site: Site) -> dict[str, Any]:
    """Serialize Site model using the same schema as AppSiteApi."""
    return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))


def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
    can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
    app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
    return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

View File

@@ -138,7 +138,9 @@ class DatasetConfigManager:
        )

    @classmethod
    def validate_and_set_defaults(
        cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
    ) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for dataset feature
@@ -172,7 +174,7 @@ class DatasetConfigManager:
        return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

    @classmethod
    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
        """
        Extract dataset config for legacy compatibility

View File

@@ -108,7 +108,7 @@ class ModelConfigManager:
        return dict(config), ["model"]

    @classmethod
    def validate_model_completion_params(cls, cp: dict[str, Any]):
        # model.completion_params
        if not isinstance(cp, dict):
            raise ValueError("model.completion_params must be of object type")

View File

@@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
        )

    @classmethod
    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate pre_prompt and set defaults for prompt feature
        depending on the config['model']
@@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
        return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]

    @classmethod
    def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
        """
        Validate post_prompt and set defaults for prompt feature

View File

@@ -1,5 +1,5 @@
import re
from typing import Any, cast

from graphon.variables.input_entities import VariableEntity, VariableEntityType
@@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
        return variable_entities, external_data_variables

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for user input form
@@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
        return config, related_config_keys

    @classmethod
    def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for user input form
@@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
        return config, ["user_input_form"]

    @classmethod
    def validate_external_data_tools_and_set_defaults(
        cls, tenant_id: str, config: dict[str, Any]
    ) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for external data fetch feature

View File

@@ -30,7 +30,7 @@ class FileUploadConfigManager:
        return FileUploadConfig.model_validate(file_upload_dict)

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for file upload feature

View File

@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, ValidationError
@@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config
@@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
        return AppConfigModel.model_validate(validated_config).more_like_this.enabled

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        try:
            return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
        except ValidationError:

View File

@@ -1,6 +1,9 @@
from typing import Any


class OpeningStatementConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
        """
        Convert model config to model config
@@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
        return opening_statement, suggested_questions_list

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for opening statement feature

View File

@@ -1,6 +1,9 @@
from typing import Any


class RetrievalResourceConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]) -> bool:
        show_retrieve_source = False
        retriever_resource_dict = config.get("retriever_resource")
        if retriever_resource_dict:
@@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
        return show_retrieve_source

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for retriever resource feature

View File

@@ -1,6 +1,9 @@
from typing import Any


class SpeechToTextConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config
@@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
        return speech_to_text

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for speech to text feature

View File

@@ -1,6 +1,9 @@
from typing import Any


class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config
@@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        return suggested_questions_after_answer

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for suggested questions feature

View File

@@ -1,9 +1,11 @@
from typing import Any

from core.app.app_config.entities import TextToSpeechEntity


class TextToSpeechConfigManager:
    @classmethod
    def convert(cls, config: dict[str, Any]):
        """
        Convert model config to model config
@@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
        return text_to_speech

    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for text to speech feature

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        return dict(blocking_response.model_dump())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
         return pipeline_config
 
     @classmethod
-    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
+    def config_validate(
+        cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
+    ) -> dict[str, Any]:
         """
         Validate for pipeline config

View File

@@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
         user_id: str,
         all_files: list,
         datasource_info: Mapping[str, Any],
-        next_page_parameters: dict | None = None,
+        next_page_parameters: dict[str, Any] | None = None,
     ):
         """
         Get files in a folder.

View File

@@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
         node_type: str
         title: str
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
         inputs_truncated: bool = False
@@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         title: str
         index: int
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)
 
     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
@@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: dict | None = None
+        extras: dict[str, Any] | None = None
         inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
@@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
         node_type: str
         title: str
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
         inputs_truncated: bool = False
@@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
         outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: dict | None = None
+        extras: dict[str, Any] | None = None
         inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
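The `extras: dict[str, Any] = Field(default_factory=dict)` fields in these stream-response models use a factory because a literal `{}` default would be one shared mutable object. A stdlib sketch of the same idea, with `dataclasses` standing in for Pydantic and an illustrative field name:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class StreamEventData:
    node_id: str
    # default_factory builds a fresh dict per instance; dataclasses (and
    # Pydantic) reject a plain `extras: dict[str, Any] = {}` default
    # precisely because it would be shared across instances.
    extras: dict[str, Any] = field(default_factory=dict)


a = StreamEventData(node_id="a")
b = StreamEventData(node_id="b")
a.extras["k"] = 1
print(b.extras)  # b is unaffected: {}
```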

View File

@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
     description: I18nObject
     parameters: list[DatasourceParameter] | None = None
     labels: list[str] = Field(default_factory=list)
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None
 
 
 ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
     icon: str | dict
     label: I18nObjectDict
     type: str
-    team_credentials: dict | None
+    team_credentials: dict[str, Any] | None
     is_team_authorization: bool
     allow_delete: bool
     datasources: list[Any]
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
     icon: str | dict
     label: I18nObject  # label
     type: str
-    masked_credentials: dict | None = None
-    original_credentials: dict | None = None
+    masked_credentials: dict[str, Any] | None = None
+    original_credentials: dict[str, Any] | None = None
     is_team_authorization: bool = False
     allow_delete: bool = True
     plugin_id: str | None = Field(default="", description="The plugin id of the datasource")

View File

@@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
     identity: DatasourceIdentity
     parameters: list[DatasourceParameter] = Field(default_factory=list)
     description: I18nObject = Field(..., description="The label of the datasource")
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None
 
     @field_validator("parameters", mode="before")
     @classmethod
@@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):
     time_cost: float = Field(..., description="The time cost of the tool invoke")
     error: str | None = None
-    tool_config: dict | None = None
+    tool_config: dict[str, Any] | None = None
 
     @classmethod
     def empty(cls) -> DatasourceInvokeMeta:
@@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):
     page_id: str = Field(..., description="The page id")
     page_name: str = Field(..., description="The page title")
-    page_icon: dict | None = Field(None, description="The page icon")
+    page_icon: dict[str, Any] | None = Field(None, description="The page icon")
     type: str = Field(..., description="The type of the page")
     last_edited_time: str = Field(..., description="The last edited time")
     parent_id: str | None = Field(None, description="The parent page id")
@@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
     Get website crawl request
     """
 
-    crawl_parameters: dict = Field(..., description="The crawl parameters")
+    crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")
 
 
 class WebSiteInfoDetail(BaseModel):
@@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
     bucket: str | None = Field(None, description="The file bucket")
     files: list[OnlineDriveFile] = Field(..., description="The file list")
     is_truncated: bool = Field(False, description="Whether the result is truncated")
-    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
 
 
 class OnlineDriveBrowseFilesRequest(BaseModel):
@@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
     bucket: str | None = Field(None, description="The file bucket")
     prefix: str = Field(..., description="The parent folder ID")
     max_keys: int = Field(20, description="Page size for pagination")
-    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
 
 
 class OnlineDriveBrowseFilesResponse(BaseModel):

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field, field_validator
@@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
     id: str
     position: int
     data_source_type: str
-    data_source_info: dict | None = None
+    data_source_info: dict[str, Any] | None = None
     name: str
     indexing_status: str
     error: str | None = None

View File

@@ -6,6 +6,7 @@ import re
 from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from json import JSONDecodeError
+from typing import Any
 
 from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
 from graphon.model_runtime.entities.provider_entities import (
@@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
            return ModelProviderFactory(model_runtime=self._bound_model_runtime)
        return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
 
-    def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
+    def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
         """
         Get current credentials.
@@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):
            return session.execute(stmt).scalar_one_or_none()
 
-    def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
+    def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
         """
         Get a specific provider credential by ID.
         :param credential_id: Credential ID
@@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
            stmt = stmt.where(ProviderCredential.id != exclude_id)
        return session.execute(stmt).scalar_one_or_none() is not None
 
-    def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
+    def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
         """
         Get provider credentials.
@@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
            else [],
        )
 
-    def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
+    def validate_provider_credentials(
+        self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
+    ):
         """
         Validate custom credentials.
         :param credentials: provider credentials
@@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
            provider_names.append(model_provider_id.provider_name)
        return provider_names
 
-    def create_provider_credential(self, credentials: dict, credential_name: str | None):
+    def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
         """
         Add custom provider credentials.
         :param credentials: provider credentials
@@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):
     def update_provider_credential(
         self,
-        credentials: dict,
+        credentials: dict[str, Any],
         credential_id: str,
         credential_name: str | None,
     ):
@@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):
     def _get_specific_custom_model_credential(
         self, model_type: ModelType, model: str, credential_id: str
-    ) -> dict | None:
+    ) -> dict[str, Any] | None:
         """
         Get a specific provider credential by ID.
         :param credential_id: Credential ID
@@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
            stmt = stmt.where(ProviderModelCredential.id != exclude_id)
        return session.execute(stmt).scalar_one_or_none() is not None
 
-    def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
+    def get_custom_model_credential(
+        self, model_type: ModelType, model: str, credential_id: str | None
+    ) -> dict[str, Any] | None:
         """
         Get custom model credentials.
@@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
         self,
         model_type: ModelType,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         credential_id: str = "",
         session: Session | None = None,
     ):
@@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
        return _validate(new_session)
 
     def create_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
+        self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
     ) -> None:
         """
         Create a custom model credential.
@@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
            raise
 
     def update_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
+        self,
+        model_type: ModelType,
+        model: str,
+        credentials: dict[str, Any],
+        credential_name: str | None,
+        credential_id: str,
     ) -> None:
         """
         Update a custom model credential.
@@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
        # Get model instance of LLM
        return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
 
-    def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
+    def get_model_schema(
+        self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
+    ) -> AIModelEntity | None:
         """
         Get model schema
         """
@@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):
        return secret_input_form_variables
 
-    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
+    def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
         """
         Obfuscated credentials.

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from enum import StrEnum, auto
-from typing import Union
+from typing import Any, Union
 
 from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, ConfigDict, Field
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
     enabled: bool
     current_quota_type: ProviderQuotaType | None = None
     quota_configurations: list[QuotaConfiguration] = []
-    credentials: dict | None = None
+    credentials: dict[str, Any] | None = None
 
 
 class CustomProviderConfiguration(BaseModel):
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
     Model class for provider custom configuration.
     """
 
-    credentials: dict
+    credentials: dict[str, Any]
     current_credential_id: str | None = None
     current_credential_name: str | None = None
     available_credentials: list[CredentialConfiguration] = []
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
     model: str
     model_type: ModelType
-    credentials: dict | None
+    credentials: dict[str, Any] | None
     current_credential_id: str | None = None
     current_credential_name: str | None = None
     available_model_credentials: list[CredentialConfiguration] = []
@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
     id: str
     name: str
-    credentials: dict
+    credentials: dict[str, Any]
     credential_source_type: str | None = None
     credential_id: str | None = None

View File

@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
 
 
 class ExternalDataToolFactory:
-    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
+    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
         extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
         self.__extension_instance = extension_class(
             tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
         )
 
     @classmethod
-    def validate_config(cls, name: str, tenant_id: str, config: dict):
+    def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.

View File

@@ -1,6 +1,7 @@
 import json
 from enum import StrEnum
 from json import JSONDecodeError
+from typing import Any
 
 from extensions.ext_redis import redis_client
@@ -15,7 +16,7 @@ class ProviderCredentialsCache:
     def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
         self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
 
-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
         """
         Get cached model provider credentials.
@@ -33,7 +34,7 @@ class ProviderCredentialsCache:
         else:
             return None
 
-    def set(self, credentials: dict):
+    def set(self, credentials: dict[str, Any]):
         """
         Cache model provider credentials.

View File

@@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
         """Generate cache key based on subclass implementation"""
         pass
 
-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
         """Get cached provider credentials"""
         cached_credentials = redis_client.get(self.cache_key)
         if cached_credentials:
@@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
 class NoOpProviderCredentialCache:
     """No-op provider credential cache"""
 
-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
         """Get cached provider credentials"""
         return None

View File

@@ -1,6 +1,7 @@
 import json
 from enum import StrEnum
 from json import JSONDecodeError
+from typing import Any
 
 from extensions.ext_redis import redis_client
@@ -18,7 +19,7 @@ class ToolParameterCache:
             f":identity_id:{identity_id}"
         )
 
-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
         """
         Get cached model provider credentials.
@@ -36,7 +37,7 @@ class ToolParameterCache:
         else:
             return None
 
-    def set(self, parameters: dict):
+    def set(self, parameters: dict[str, Any]):
         """Cache model provider credentials."""
         redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

View File

@@ -735,7 +735,9 @@ class IndexingRunner:
     @staticmethod
     def _update_document_index_status(
-        document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
+        document_id: str,
+        after_indexing_status: IndexingStatus,
+        extra_update_params: dict[Any, Any] | None = None,
     ):
         """
         Update the document indexing status.
@@ -762,7 +764,7 @@ class IndexingRunner:
         db.session.commit()
 
     @staticmethod
-    def _update_segments_by_document(dataset_document_id: str, update_params: dict):
+    def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
         """
         Update the document segment by document id.
         """

View File

@@ -200,7 +200,7 @@ def _handle_native_json_schema(
     provider: str,
     model_schema: AIModelEntity,
     structured_output_schema: Mapping,
-    model_parameters: dict,
+    model_parameters: dict[str, Any],
     rules: list[ParameterRule],
 ):
     """
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
    return model_parameters
 
 
-def _set_response_format(model_parameters: dict, rules: list):
+def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
    """
    Set the appropriate response format parameter based on model rules.
@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
    return {"schema": processed_schema, "name": "llm_response"}
 
 
-def remove_additional_properties(schema: dict):
+def remove_additional_properties(schema: dict[str, Any]):
    """
    Remove additionalProperties fields from JSON schema.
    Used for models like Gemini that don't support this property.

View File

@@ -77,7 +77,7 @@ class ModelInstance:
     @staticmethod
     def _get_load_balancing_manager(
-        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
+        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
     ) -> Optional["LBModelManager"]:
         """
         Get load balancing model credentials
@@ -115,7 +115,7 @@ class ModelInstance:
     def invoke_llm(
         self,
         prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: Literal[True] = True,
@@ -126,7 +126,7 @@ class ModelInstance:
     def invoke_llm(
         self,
         prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: Literal[False] = False,
@@ -137,7 +137,7 @@ class ModelInstance:
     def invoke_llm(
         self,
         prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: bool = True,
@@ -147,7 +147,7 @@ class ModelInstance:
     def invoke_llm(
         self,
         prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
         tools: Sequence[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
@@ -528,7 +528,7 @@ class LBModelManager:
         model_type: ModelType,
         model: str,
         load_balancing_configs: list[ModelLoadBalancingConfiguration],
-        managed_credentials: dict | None = None,
+        managed_credentials: dict[str, Any] | None = None,
     ):
         """
         Load balancing model manager

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field
 from sqlalchemy import select
@@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension
 class ModerationInputParams(BaseModel):
     app_id: str = ""
-    inputs: dict = Field(default_factory=dict)
+    inputs: dict[str, Any] = Field(default_factory=dict)
     query: str = ""
@@ -23,7 +25,7 @@ class ApiModeration(Moderation):
     name: str = "api"
 
     @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.
@@ -41,7 +43,7 @@ class ApiModeration(Moderation):
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")
 
-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
 
         if self.config is None:
@@ -73,7 +75,7 @@ class ApiModeration(Moderation):
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )
 
-    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
+    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
        if self.config is None:
            raise ValueError("The config is not set.")
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

View File

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import StrEnum, auto
+from typing import Any

 from pydantic import BaseModel, Field
@@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
     flagged: bool = False
     action: ModerationAction
     preset_response: str = ""
-    inputs: dict = Field(default_factory=dict)
+    inputs: dict[str, Any] = Field(default_factory=dict)
     query: str = ""
@@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):
     module: ExtensionModule = ExtensionModule.MODERATION

-    def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
+    def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
         super().__init__(tenant_id, config)
         self.app_id = app_id

     @classmethod
     @abstractmethod
-    def validate_config(cls, tenant_id: str, config: dict) -> None:
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
         """
         Validate the incoming form config data.
@@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
         raise NotImplementedError

     @abstractmethod
-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
         """
         Moderation for inputs.
         After the user inputs, this method will be called to perform sensitive content review
@@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
         raise NotImplementedError

     @classmethod
-    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
+    def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
         # inputs_config
         inputs_config = config.get("inputs_config")
         if not isinstance(inputs_config, dict):
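The change above is the pattern applied throughout this commit: bare `dict` annotations become `dict[str, Any]`. A minimal standalone sketch (the class name here is hypothetical, using a stdlib dataclass rather than pydantic) of what the parameterized form buys a type checker:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ModerationResultSketch:
    # Bare `dict` is implicitly dict[Any, Any] to a checker; dict[str, Any]
    # records that keys are strings while leaving values unconstrained.
    inputs: dict[str, Any] = field(default_factory=dict)


result = ModerationResultSketch(inputs={"query": "hello", "flagged": False})
# Key-based operations now type-check as str without casts:
print(sorted(result.inputs))  # → ['flagged', 'query']
```

The runtime behavior is identical; the benefit is purely in static analysis, which is why the diff can swap annotations without touching any call sites.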

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.extension.extensible import ExtensionModule
 from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
 from extensions.ext_code_based_extension import code_based_extension
@@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
 class ModerationFactory:
     __extension_instance: Moderation

-    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
+    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
         extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
         self.__extension_instance = extension_class(app_id, tenant_id, config)

     @classmethod
-    def validate_config(cls, name: str, tenant_id: str, config: dict):
+    def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.
@@ -24,7 +26,7 @@ class ModerationFactory:
         # FIXME: mypy error, try to fix it instead of using type: ignore
         extension_class.validate_config(tenant_id, config)  # type: ignore

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
         """
         Moderation for inputs.
         After the user inputs, this method will be called to perform sensitive content review

View File

@@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
     name: str = "keywords"

     @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.
@@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
         if len(keywords_row_len) > 100:
             raise ValueError("the number of rows for the keywords must be less than 100")

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
         if self.config is None:
@@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
             flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
         )

-    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
+    def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
         return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

     def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from graphon.model_runtime.entities.model_entities import ModelType

 from core.model_manager import ModelManager
@@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
     name: str = "openai_moderation"

     @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
         """
         Validate the incoming form config data.
@@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
         """
         cls._validate_inputs_and_outputs_config(config, True)

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
         flagged = False
         preset_response = ""
         if self.config is None:
@@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
             flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
         )

-    def _is_violated(self, inputs: dict):
+    def _is_violated(self, inputs: dict[str, Any]):
         text = "\n".join(str(inputs.values()))

         model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
         model_instance = model_manager.get_model_instance(

View File

@@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
             raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")

-    def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+    def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
         """Construct LLM attributes with passed prompts for Arize/Phoenix."""
         attributes: dict[str, str] = {}
@@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
                 set_attribute(path, value)

-        def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
+        def set_tool_call_attributes(
+            message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
+        ) -> None:
             """Extract and assign tool call details safely."""
             if not tool_call:
                 return

View File

@@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
         )
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

+    @staticmethod
+    def _get_completion_start_time(
+        start_time: datetime | None, time_to_first_token: float | int | None
+    ) -> datetime | None:
+        """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
+        if start_time is None or time_to_first_token is None:
+            return None
+        try:
+            ttft_seconds = float(time_to_first_token)
+        except (TypeError, ValueError):
+            return None
+        if ttft_seconds < 0:
+            return None
+        return start_time + timedelta(seconds=ttft_seconds)
+
     def trace(self, trace_info: BaseTraceInfo):
         if isinstance(trace_info, WorkflowTraceInfo):
             self.workflow_trace(trace_info)
@@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
             total_token = metadata.get("total_tokens", 0)
             prompt_tokens = 0
             completion_tokens = 0
+            completion_start_time = None
             try:
-                usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                usage_data = process_data.get("usage")
+                if not isinstance(usage_data, dict):
+                    usage_data = outputs.get("usage")
+                if not isinstance(usage_data, dict):
+                    usage_data = {}
                 prompt_tokens = usage_data.get("prompt_tokens", 0)
                 completion_tokens = usage_data.get("completion_tokens", 0)
+                completion_start_time = self._get_completion_start_time(
+                    created_at, usage_data.get("time_to_first_token")
+                )
             except Exception:
                 logger.error("Failed to extract usage", exc_info=True)
@@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 trace_id=trace_id,
                 model=process_data.get("model_name"),
                 start_time=created_at,
+                completion_start_time=completion_start_time,
                 end_time=finished_at,
                 input=inputs,
                 output=outputs,
@@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
                 unit=UnitEnum.TOKENS,
                 totalCost=message_data.total_price,
             )

+        completion_start_time = self._get_completion_start_time(
+            trace_info.start_time,
+            trace_info.gen_ai_server_time_to_first_token,
+        )
+
         langfuse_generation_data = LangfuseGeneration(
             name="llm",
             trace_id=trace_id,
             start_time=trace_info.start_time,
+            completion_start_time=completion_start_time,
             end_time=trace_info.end_time,
             model=message_data.model_id,
             input=trace_info.inputs,

View File

@@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
         return inputs, attributes

-    def _parse_knowledge_retrieval_outputs(self, outputs: dict):
+    def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
         """Parse KR outputs and attributes from KR workflow node"""
         retrieved = outputs.get("result", [])
@@ -319,7 +319,7 @@
             end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
         )

-    def _get_message_user_id(self, metadata: dict) -> str | None:
+    def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
         if (end_user_id := metadata.get("from_end_user_id")) and (
             end_user_data := db.session.get(EndUser, end_user_id)
         ):
@@ -468,7 +468,7 @@
         }
         return node_type_mapping.get(node_type, "CHAIN")  # type: ignore[arg-type,call-overload]

-    def _set_trace_metadata(self, span: Span, metadata: dict):
+    def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
         token = None
         try:
             # NB: Set span in context such that we can use update_current_trace() API
@@ -490,7 +490,7 @@
             return messages
         return prompts  # Fallback to original format

-    def _parse_single_message(self, item: dict):
+    def _parse_single_message(self, item: dict[str, Any]):
         """Postprocess single message format to be standard chat message"""
         role = item.get("role", "user")
         msg = {"role": role, "content": item.get("text", "")}

View File

@@ -3,7 +3,7 @@ import logging
 import os
 import uuid
 from datetime import datetime, timedelta
-from typing import cast
+from typing import Any, cast

 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from opik import Opik, Trace
@@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
         self.add_span(span_data)

-    def add_trace(self, opik_trace_data: dict) -> Trace:
+    def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
         try:
             trace = self.opik_client.trace(**opik_trace_data)
             logger.debug("Opik Trace created successfully")
@@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
         except Exception as e:
             raise ValueError(f"Opik Failed to create trace: {str(e)}")

-    def add_span(self, opik_span_data: dict):
+    def add_span(self, opik_span_data: dict[str, Any]):
         try:
             self.opik_client.span(**opik_span_data)
             logger.debug("Opik Span created successfully")

View File

@@ -324,7 +324,7 @@ class OpsTraceManager:
     @classmethod
     def encrypt_tracing_config(
-        cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
+        cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
    ):
         """
         Encrypt tracing config.
@@ -363,7 +363,7 @@
         return encrypted_config.model_dump()

     @classmethod
-    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
+    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
         """
         Decrypt tracing config
         :param tenant_id: tenant id
@@ -408,7 +408,7 @@
         return dict(decrypted_config)

     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -581,7 +581,7 @@
         return app_trace_config

     @staticmethod
-    def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
+    def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
         """
         Check trace config is effective
         :param tracing_config: tracing config
@@ -596,7 +596,7 @@
         return trace_instance(config).api_check()

     @staticmethod
-    def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
+    def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
         """
         get trace config is project key
         :param tracing_config: tracing config
@@ -611,7 +611,7 @@
         return trace_instance(config).get_project_key()

     @staticmethod
-    def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
+    def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
         """
         get trace config is project key
         :param tracing_config: tracing config
@@ -1322,8 +1322,8 @@
             error=error,
         )

-    def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
-        node_data: dict = kwargs.get("node_execution_data", {})
+    def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
+        node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
         if not node_data:
             return {}
@@ -1431,7 +1431,7 @@
             return node_trace
         return DraftNodeExecutionTrace(**node_trace.model_dump())

-    def _extract_streaming_metrics(self, message_data) -> dict:
+    def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
         if not message_data.message_metadata:
             return {}

View File

@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any

 from pydantic import BaseModel, Field, model_validator
@@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
     entity of an endpoint
     """

-    settings: dict
+    settings: dict[str, Any]
     tenant_id: str
     plugin_id: str
     expired_at: datetime

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from pydantic import BaseModel, Field, computed_field, model_validator
@@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):
     @model_validator(mode="before")
     @classmethod
-    def transform_declaration(cls, data: dict):
+    def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
         if "endpoint" in data and not data["endpoint"]:
             del data["endpoint"]
         if "model" in data and not data["model"]:

View File

@@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
     @model_validator(mode="before")
     @classmethod
-    def validate_category(cls, values: dict):
+    def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
         # auto detect category
         if values.get("tool"):
             values["category"] = PluginCategory.Tool

View File

@@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
     """

     result: bool
-    credentials: dict | None = None
+    credentials: dict[str, Any] | None = None


 class PluginModelSchemaEntity(BaseModel):

View File

@@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
     tool_type: Literal["builtin", "workflow", "api", "mcp"]
     provider: str
     tool: str
-    tool_parameters: dict
+    tool_parameters: dict[str, Any]
     credential_id: str | None = None
@@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
     opt: Literal["encrypt", "decrypt", "clear"]
     namespace: Literal["endpoint"]
     identity: str
-    data: dict = Field(default_factory=dict)
+    data: dict[str, Any] = Field(default_factory=dict)
     config: list[BasicProviderConfig] = Field(default_factory=list)

View File

@@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
         Fetch datasource providers for the given tenant.
         """

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
             if json_response.get("data"):
                 for provider in json_response.get("data", []):
                     declaration = provider.get("declaration", {}) or {}
@@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
         Fetch datasource providers for the given tenant.
         """

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
             if json_response.get("data"):
                 for provider in json_response.get("data", []):
                     declaration = provider.get("declaration", {}) or {}
@@ -110,7 +110,7 @@
         tool_provider_id = DatasourceProviderID(provider_id)

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
             data = json_response.get("data")
             if data:
                 for datasource in data.get("declaration", {}).get("datasources", []):

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.plugin.entities.endpoint import EndpointEntityWithInstance
 from core.plugin.impl.base import BasePluginClient
 from core.plugin.impl.exc import PluginDaemonInternalServerError
@@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
 class PluginEndpointClient(BasePluginClient):
     def create_endpoint(
-        self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
+        self,
+        tenant_id: str,
+        user_id: str,
+        plugin_unique_identifier: str,
+        name: str,
+        settings: dict[str, Any],
     ) -> bool:
         """
         Create an endpoint for the given plugin.
@@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
             params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
         )

-    def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
+    def update_endpoint(
+        self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
+    ) -> bool:
         """
         Update the settings of the given endpoint.
         """

View File

@@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
         provider: str,
         model_type: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
     ) -> AIModelEntity | None:
         """
         Get model schema
@@ -80,7 +80,7 @@
         return None

     def validate_provider_credentials(
-        self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
+        self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
     ) -> bool:
         """
         validate the credentials of the provider
@@ -118,7 +118,7 @@
         provider: str,
         model_type: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
     ) -> bool:
         """
         validate the credentials of the provider
@@ -157,9 +157,9 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
         tools: list[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: bool = True,
@@ -206,7 +206,7 @@
         provider: str,
         model_type: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         prompt_messages: list[PromptMessage],
         tools: list[PromptMessageTool] | None = None,
     ) -> int:
@@ -248,7 +248,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         texts: list[str],
         input_type: str,
     ) -> EmbeddingResult:
@@ -290,7 +290,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         documents: list[dict],
         input_type: str,
     ) -> EmbeddingResult:
@@ -332,7 +332,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         texts: list[str],
     ) -> list[int]:
         """
@@ -372,7 +372,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         query: str,
         docs: list[str],
         score_threshold: float | None = None,
@@ -418,7 +418,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         query: MultimodalRerankInput,
         docs: list[MultimodalRerankInput],
         score_threshold: float | None = None,
@@ -463,7 +463,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         content_text: str,
         voice: str,
     ) -> Generator[bytes, None, None]:
@@ -508,7 +508,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         language: str | None = None,
     ):
         """
@@ -552,7 +552,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         file: IO[bytes],
     ) -> str:
         """
@@ -592,7 +592,7 @@
         plugin_id: str,
         provider: str,
         model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
         text: str,
     ) -> bool:
         """

View File

@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from typing import Any

 from requests import HTTPError
@@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
         original_plugin_unique_identifier: str,
         new_plugin_unique_identifier: str,
         source: PluginInstallationSource,
-        meta: dict,
+        meta: dict[str, Any],
     ) -> PluginInstallTaskStartResponse:
         """
         Upgrade a plugin.

View File

@@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
         app_mode: AppMode,
         model_config: ModelConfigWithCredentialsEntity,
         pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
         query: str | None = None,
         context: str | None = None,
         histories: str | None = None,
-    ) -> tuple[str, dict]:
+    ) -> tuple[str, dict[str, Any]]:
         # get prompt template
         prompt_template_config = self.get_prompt_template(
             app_mode=app_mode,
@@ -187,7 +187,7 @@
         self,
         app_mode: AppMode,
         pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
         query: str,
         context: str | None,
         files: Sequence["File"],
@@ -234,7 +234,7 @@
         self,
         app_mode: AppMode,
         pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
         query: str,
         context: str | None,
         files: Sequence["File"],

View File

@@ -856,7 +856,7 @@ class ProviderManager:
         secret_variables: list[str],
         cache_type: ProviderCredentialsCacheType,
         is_provider: bool = False,
-    ) -> dict:
+    ) -> dict[str, Any]:
         """Get and decrypt credentials with caching."""
         credentials_cache = ProviderCredentialsCache(
             tenant_id=tenant_id,

View File

@@ -174,8 +174,8 @@ class RetrievalService:
         cls,
         dataset_id: str,
         query: str,
-        external_retrieval_model: dict | None = None,
-        metadata_filtering_conditions: dict | None = None,
+        external_retrieval_model: dict[str, Any] | None = None,
+        metadata_filtering_conditions: dict[str, Any] | None = None,
     ):
         stmt = select(Dataset).where(Dataset.id == dataset_id)
         dataset = db.session.scalar(stmt)

View File

@@ -0,0 +1,87 @@
+"""Vector store backend discovery.
+
+Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
+declares third-party dependencies and registers ``importlib`` entry points in group
+``dify.vector_backends`` (see each package's ``pyproject.toml``).
+
+Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
+remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
+
+Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
+distribution; entry points take precedence when both exist.
+
+After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
+"""
+
+from __future__ import annotations
+
+import importlib
+import logging
+from importlib.metadata import entry_points
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+
+logger = logging.getLogger(__name__)
+
+_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
+
+# module_path:class_name — optional fallback when no distribution registers the backend.
+_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
+
+
+def clear_vector_factory_cache() -> None:
+    """Drop lazily loaded factories (for tests or plugin reload)."""
+    _VECTOR_FACTORY_CACHE.clear()
+
+
+def _vector_backend_entry_points():
+    return entry_points().select(group="dify.vector_backends")
+
+
+def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
+    for ep in _vector_backend_entry_points():
+        if ep.name != vector_type:
+            continue
+        try:
+            loaded = ep.load()
+        except Exception:
+            logger.exception("Failed to load vector backend entry point %s", ep.name)
+            raise
+        return loaded  # type: ignore[return-value]
+    return None
+
+
+def _unsupported(vector_type: str) -> ValueError:
+    installed = sorted(ep.name for ep in _vector_backend_entry_points())
+    available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
+    return ValueError(
+        f"Vector store {vector_type!r} is not supported.{available_msg} "
+        "Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
+        "or register a dify.vector_backends entry point."
+    )
+
+
+def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
+    target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
+    if not target:
+        raise _unsupported(vector_type)
+    module_path, _, attr = target.partition(":")
+    module = importlib.import_module(module_path)
+    return getattr(module, attr)  # type: ignore[no-any-return]
+
+
+def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
+    """Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
+    if vector_type in _VECTOR_FACTORY_CACHE:
+        return _VECTOR_FACTORY_CACHE[vector_type]
+    plugin_cls = _load_plugin_factory(vector_type)
+    if plugin_cls is not None:
+        _VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
+        return plugin_cls
+    cls = _load_builtin_factory(vector_type)
+    _VECTOR_FACTORY_CACHE[vector_type] = cls
+    return cls

View File

@@ -9,6 +9,7 @@ from sqlalchemy import select
 from configs import dify_config
 from core.model_manager import ModelManager
+from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
 from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.cached_embedding import CacheEmbedding
@@ -85,137 +86,7 @@ class Vector:
     @staticmethod
     def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
-        match vector_type:
-            case VectorType.CHROMA:
-                from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
-
-                return ChromaVectorFactory
-            case VectorType.MILVUS:
-                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
-
-                return MilvusVectorFactory
-            case VectorType.ALIBABACLOUD_MYSQL:
-                from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
-                    AlibabaCloudMySQLVectorFactory,
-                )
-
-                return AlibabaCloudMySQLVectorFactory
-            case VectorType.MYSCALE:
-                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
-
-                return MyScaleVectorFactory
-            case VectorType.PGVECTOR:
-                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
-
-                return PGVectorFactory
-            case VectorType.VASTBASE:
-                from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
-
-                return VastbaseVectorFactory
-            case VectorType.PGVECTO_RS:
-                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
-
-                return PGVectoRSFactory
-            case VectorType.QDRANT:
-                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
-
-                return QdrantVectorFactory
-            case VectorType.RELYT:
-                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
-
-                return RelytVectorFactory
-            case VectorType.ELASTICSEARCH:
-                from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
-
-                return ElasticSearchVectorFactory
-            case VectorType.ELASTICSEARCH_JA:
-                from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
-                    ElasticSearchJaVectorFactory,
-                )
-
-                return ElasticSearchJaVectorFactory
-            case VectorType.TIDB_VECTOR:
-                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
-
-                return TiDBVectorFactory
-            case VectorType.WEAVIATE:
-                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
-
-                return WeaviateVectorFactory
-            case VectorType.TENCENT:
-                from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
-
-                return TencentVectorFactory
-            case VectorType.ORACLE:
-                from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
-
-                return OracleVectorFactory
-            case VectorType.OPENSEARCH:
-                from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
-
-                return OpenSearchVectorFactory
-            case VectorType.ANALYTICDB:
-                from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
-
-                return AnalyticdbVectorFactory
-            case VectorType.COUCHBASE:
-                from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
-
-                return CouchbaseVectorFactory
-            case VectorType.BAIDU:
-                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
-
-                return BaiduVectorFactory
-            case VectorType.VIKINGDB:
-                from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
-
-                return VikingDBVectorFactory
-            case VectorType.UPSTASH:
-                from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
-
-                return UpstashVectorFactory
-            case VectorType.TIDB_ON_QDRANT:
-                from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
-
-                return TidbOnQdrantVectorFactory
-            case VectorType.LINDORM:
-                from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
-
-                return LindormVectorStoreFactory
-            case VectorType.OCEANBASE | VectorType.SEEKDB:
-                from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
-
-                return OceanBaseVectorFactory
-            case VectorType.OPENGAUSS:
-                from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
-
-                return OpenGaussFactory
-            case VectorType.TABLESTORE:
-                from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
-
-                return TableStoreVectorFactory
-            case VectorType.HUAWEI_CLOUD:
-                from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
-
-                return HuaweiCloudVectorFactory
-            case VectorType.MATRIXONE:
-                from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
-
-                return MatrixoneVectorFactory
-            case VectorType.CLICKZETTA:
-                from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
-
-                return ClickzettaVectorFactory
-            case VectorType.IRIS:
-                from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
-
-                return IrisVectorFactory
-            case VectorType.HOLOGRES:
-                from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
-
-                return HologresVectorFactory
-            case _:
-                raise ValueError(f"Vector store {vector_type} is not supported.")
+        return get_vector_factory_class(vector_type)

     def create(self, texts: list | None = None, **kwargs):
         if texts:

View File

@@ -1,10 +1,19 @@
+"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).
+
+:class:`AbstractVectorTest` and helper functions live here so package tests can import
+``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
+``tests.*`` package.
+
+The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
+auto-discovered by pytest for all package tests.
+"""
+
 import uuid
-from unittest.mock import MagicMock
 
 import pytest
 
+from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.models.document import Document
-from extensions import ext_redis
 from models.dataset import Dataset
@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
     return doc
 
 
-@pytest.fixture
-def setup_mock_redis():
-    # get
-    ext_redis.redis_client.get = MagicMock(return_value=None)
-
-    # set
-    ext_redis.redis_client.set = MagicMock(return_value=None)
-
-    # lock
-    mock_redis_lock = MagicMock()
-    mock_redis_lock.__enter__ = MagicMock()
-    mock_redis_lock.__exit__ = MagicMock()
-    ext_redis.redis_client.lock = mock_redis_lock
-
-
 class AbstractVectorTest:
+    vector: BaseVector
+
     def __init__(self):
-        self.vector = None
         self.dataset_id = str(uuid.uuid4())
         self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
         self.example_doc_id = str(uuid.uuid4())
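The fixture removed above now lives in `api/packages/conftest.py`, per the new module docstring. Its Redis-stubbing pattern can be sketched stand-alone with `MagicMock`; `FakeRedisHolder` below is a hypothetical stand-in for `ext_redis`, used only so the sketch runs without the real extension:

```python
from unittest.mock import MagicMock


class FakeRedisHolder:
    """Hypothetical stand-in for Dify's ext_redis module."""

    redis_client = MagicMock()


def setup_mock_redis(holder=FakeRedisHolder):
    # Stub the calls integration tests touch: get, set, and the lock context manager.
    holder.redis_client.get = MagicMock(return_value=None)
    holder.redis_client.set = MagicMock(return_value=None)
    lock = MagicMock()
    lock.__enter__ = MagicMock()
    lock.__exit__ = MagicMock(return_value=False)
    holder.redis_client.lock = MagicMock(return_value=lock)
    return holder.redis_client
```

Any code under test that calls `redis_client.get(...)` or enters `with redis_client.lock(...)` then runs without a Redis server.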

View File

@@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):
         return embedding_results  # type: ignore
 
-    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
+    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
         """Embed multimodal documents."""
         # use doc embedding cache or store if not exists
         file_id = multimodel_document["file_id"]

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
class Embeddings(ABC): class Embeddings(ABC):
@ -20,7 +21,7 @@ class Embeddings(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
"""Embed multimodal query.""" """Embed multimodal query."""
raise NotImplementedError raise NotImplementedError

View File

@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations.""" """Abstract interface for document loader implementations."""
import csv import csv
from typing import Any
import pandas as pd import pandas as pd
@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor):
encoding: str | None = None, encoding: str | None = None,
autodetect_encoding: bool = False, autodetect_encoding: bool = False,
source_column: str | None = None, source_column: str | None = None,
csv_args: dict | None = None, csv_args: dict[str, Any] | None = None,
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path

View File

@@ -54,8 +54,8 @@ class BaseAPIClient:
         self,
         method: str,
         endpoint: str,
-        query_params: dict | None = None,
-        data: dict | None = None,
+        query_params: dict[str, Any] | None = None,
+        data: dict[str, Any] | None = None,
         **kwargs,
     ) -> Response:
         stream = kwargs.pop("stream", False)
@@ -66,19 +66,25 @@ class BaseAPIClient:
         return self.session.request(method, url, params=query_params, json=data, **kwargs)
 
-    def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
+    def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
         return self._request("GET", endpoint, query_params=query_params, **kwargs)
 
-    def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
+    def _post(
+        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
+    ):
         return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)
 
-    def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
+    def _put(
+        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
+    ):
         return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)
 
-    def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
+    def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
         return self._request("DELETE", endpoint, query_params=query_params, **kwargs)
 
-    def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
+    def _patch(
+        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
+    ):
         return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)
@@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
         finally:
             response.close()
 
-    def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
+    def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator:
         if response.status_code == 401:
             raise WaterCrawlAuthenticationError(response)
@@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
         yield from generator
 
     def get_crawl_request_results(
-        self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
+        self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None
     ):
         query_params = query_params or {}
         query_params.update({"page": page or 1, "page_size": page_size or 25})
@@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
             if event_data["type"] == "result":
                 return event_data["data"]
 
-    def download_result(self, result_object: dict):
+    def download_result(self, result_object: dict[str, Any]):
         response = httpx.get(result_object["result"], timeout=None)
         try:
             response.raise_for_status()

View File

@@ -120,7 +120,7 @@ class WaterCrawlProvider:
         }
 
     def _get_results(
-        self, crawl_request_id: str, query_params: dict | None = None
+        self, crawl_request_id: str, query_params: dict[str, Any] | None = None
     ) -> Generator[WatercrawlDocumentData, None, None]:
         page = 0
         page_size = 100

View File

@@ -875,7 +875,11 @@ class DatasetRetrieval:
         return retrieval_resource_list
 
     def _on_retrieval_end(
-        self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
+        self,
+        flask_app: Flask,
+        documents: list[Document],
+        message_id: str | None = None,
+        timer: dict[str, Any] | None = None,
     ):
         """Handle retrieval end."""
         with flask_app.app_context():
@@ -980,7 +984,7 @@ class DatasetRetrieval:
             self._send_trace_task(message_id, documents, timer)
 
-    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
+    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
         """Send trace task if trace manager is available."""
         trace_manager: TraceQueueManager | None = (
             self.application_generate_entity.trace_manager if self.application_generate_entity else None
@@ -1142,7 +1146,7 @@ class DatasetRetrieval:
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
         user_id: str,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> list[DatasetRetrieverBaseTool] | None:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -1337,7 +1341,7 @@ class DatasetRetrieval:
         metadata_filtering_mode: str,
         metadata_model_config: ModelConfig,
         metadata_filtering_conditions: MetadataFilteringCondition | None,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
         document_query = select(DatasetDocument).where(
             DatasetDocument.dataset_id.in_(dataset_ids),
@@ -1417,7 +1421,7 @@ class DatasetRetrieval:
                 metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
         return metadata_filter_document_ids, metadata_condition
 
-    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+    def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
         if not inputs:
             return text

View File

@@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
     return _case_routing
 
 
-def __getattr__(name: str) -> dict:
+def __getattr__(name: str) -> Any:
     """Lazy module-level access to routing tables."""
     if name == "CASE_ROUTING":
         return _get_case_routing()

View File

@@ -198,7 +198,7 @@ class Tool(ABC):
             message=ToolInvokeMessage.TextMessage(text=text),
         )
 
-    def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
+    def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
         """
         create a blob message
@@ -212,7 +212,7 @@ class Tool(ABC):
             meta=meta,
         )
 
-    def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
+    def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
         """
         create a json message
         """

View File

@@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
         text: str
 
     class JsonMessage(BaseModel):
-        json_object: dict | list
+        json_object: dict[str, Any] | list[Any]
         suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
 
     class BlobMessage(BaseModel):
@@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
     form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
     llm_description: str | None = None
     # MCP object and array type parameters use this field to store the schema
-    input_schema: dict | None = None
+    input_schema: dict[str, Any] | None = None
 
     @classmethod
     def get_simple_instance(
@@ -463,7 +463,7 @@ class ToolInvokeMeta(BaseModel):
     time_cost: float = Field(..., description="The time cost of the tool invoke")
     error: str | None = None
-    tool_config: dict | None = None
+    tool_config: dict[str, Any] | None = None
 
     @classmethod
     def empty(cls) -> ToolInvokeMeta:

View File

@@ -85,7 +85,8 @@ class ToolEngine:
         invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
 
         def message_callback(
-            invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
+            invocation_meta_dict: dict[str, Any],
+            messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
         ):
             for message in messages:
                 if isinstance(message, ToolInvokeMeta):
@@ -200,7 +201,7 @@ class ToolEngine:
     @staticmethod
     def _invoke(
         tool: Tool,
-        tool_parameters: dict,
+        tool_parameters: dict[str, Any],
         user_id: str,
         conversation_id: str | None = None,
         app_id: str | None = None,

View File

@@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
         user_id: str,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> list["DatasetRetrieverTool"]:
         """
         get dataset tool

View File

@@ -4,6 +4,7 @@ from collections.abc import Generator
 from datetime import date, datetime
 from decimal import Decimal
 from mimetypes import guess_extension
+from typing import Any
 from uuid import UUID
 
 import numpy as np
@@ -50,7 +51,7 @@ def safe_json_value(v):
     return v
 
 
-def safe_json_dict(d: dict):
+def safe_json_dict(d: dict[str, Any]):
     if not isinstance(d, dict):
         raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
     return {k: safe_json_value(v) for k, v in d.items()}
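The `safe_json_dict` helper typed above normalizes values that `json.dumps` cannot serialize directly. A simplified stand-alone sketch of the idea (the real helper also handles numpy types and more; this covers only the stdlib cases):

```python
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID


def safe_json_value(v):
    # Convert common non-JSON-native types into serializable forms, recursively.
    if isinstance(v, (datetime, date)):
        return v.isoformat()
    if isinstance(v, Decimal):
        return float(v)
    if isinstance(v, UUID):
        return str(v)
    if isinstance(v, dict):
        return {k: safe_json_value(x) for k, x in v.items()}
    if isinstance(v, (list, tuple)):
        return [safe_json_value(x) for x in v]
    return v


def safe_json_dict(d: dict[str, Any]) -> dict[str, Any]:
    if not isinstance(d, dict):
        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
    return {k: safe_json_value(v) for k, v in d.items()}
```

The result can be handed straight to `json.dumps` without a custom encoder.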
@@ -196,11 +197,11 @@ class ToolFileMessageTransformer:
     @staticmethod
     def _with_tool_file_meta(
-        meta: dict | None,
+        meta: dict[str, Any] | None,
         *,
         tool_file_id: str | None = None,
         url: str | None = None,
-    ) -> dict:
+    ) -> dict[str, Any]:
         normalized_meta = meta.copy() if meta is not None else {}
         resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
         if resolved_tool_file_id and "tool_file_id" not in normalized_meta:

View File

@@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):
 class ApiBasedToolSchemaParser:
     @staticmethod
     def parse_openapi_to_tool_bundle(
-        openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
+        openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
@@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser:
         return value
 
     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
+    def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
         parameter = parameter or {}
         typ: str | None = None
         if parameter.get("format") == "binary":
@@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser:
     @staticmethod
     def parse_openapi_yaml_to_tool_bundle(
-        yaml: str, extra_info: dict | None = None, warning: dict | None = None
+        yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         """
         parse openapi yaml to tool bundle
@@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
 
-        openapi: dict = safe_load(yaml)
+        openapi: dict[str, Any] = safe_load(yaml)
         if openapi is None:
             raise ToolApiSchemaError("Invalid openapi yaml.")
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
 
     @staticmethod
     def parse_swagger_to_openapi(
-        swagger: dict, extra_info: dict | None = None, warning: dict | None = None
+        swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> OpenAPISpecDict:
         warning = warning or {}
         """
@@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser:
     @staticmethod
     def parse_openai_plugin_json_to_tool_bundle(
-        json: str, extra_info: dict | None = None, warning: dict | None = None
+        json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         """
         parse openapi plugin yaml to tool bundle
@@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser:
     @staticmethod
     def auto_parse_to_tool_bundle(
-        content: str, extra_info: dict | None = None, warning: dict | None = None
+        content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
         """
         auto parse to tool bundle

View File

@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
             session.expunge(app)
             return app
 
-    def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
+    def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
         """
         transform the tool parameters
@@ -323,7 +323,7 @@ class WorkflowTool(Tool):
         return parameters_result, files
 
-    def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
+    def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
         """
         extract files from the result
@@ -355,7 +355,7 @@ class WorkflowTool(Tool):
         return result, files
 
-    def _update_file_mapping(self, file_dict: dict):
+    def _update_file_mapping(self, file_dict: dict[str, Any]):
         file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
         transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
         match transfer_method:
match transfer_method: match transfer_method:

View File

@@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol):
         original_document_id: str,
         chunks: Mapping[str, Any],
         batch: Any,
-        summary_index_setting: dict | None = None,
+        summary_index_setting: dict[str, Any] | None = None,
     ) -> IndexingResultDict: ...
 
     def get_preview_output(
-        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
+        self,
+        chunks: Any,
+        dataset_id: str,
+        document_id: str,
+        chunk_structure: str,
+        summary_index_setting: dict[str, Any] | None,
     ) -> Preview: ...
 
 
 class SummaryIndexServiceProtocol(Protocol):
     def generate_and_vectorize_summary(
-        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
+        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
     ) -> None: ...

View File

@ -1,4 +1,4 @@
from typing import Literal, Union from typing import Any, Literal, Union
from graphon.entities.base_node_data import BaseNodeData from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType from graphon.enums import NodeType
@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData):
mode: str = Field(default="visual", description="Schedule mode: visual or cron") mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
visual_config: dict | None = Field(default=None, description="Visual configuration details") visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details")
timezone: str = Field(default="UTC", description="Timezone for schedule execution") timezone: str = Field(default="UTC", description="Timezone for schedule execution")

View File

@@ -75,7 +75,7 @@ class TriggerWebhookNode(Node[WebhookData]):
             outputs=outputs,
         )
 
-    def generate_file_var(self, param_name: str, file: dict):
+    def generate_file_var(self, param_name: str, file: dict[str, Any]):
         file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
         transfer_method_value = file.get("transfer_method")
         if transfer_method_value:
@@ -147,7 +147,7 @@ class TriggerWebhookNode(Node[WebhookData]):
                 outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                 continue
             elif self.node_data.content_type == ContentType.BINARY:
-                raw_data: dict = webhook_data.get("body", {}).get("raw", {})
+                raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {})
                 file_var = self.generate_file_var(param_name, raw_data)
                 if file_var:
                     outputs[param_name] = file_var

View File

@@ -10,6 +10,7 @@ import tempfile
 from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path
+from typing import Any
 
 import clickzetta
 from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         """Validate the configuration values.
 
         This method will first try to use CLICKZETTA_VOLUME_* environment variables,
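A `mode="before"` model validator receives the raw, not-yet-validated input, so `dict[str, Any]` is the natural annotation here. The sketch below shows the general pattern the docstring describes — backfilling missing values from `CLICKZETTA_VOLUME_*` environment variables before field validation — as a plain function; the field names and the exact precedence rules are assumptions, not ClickZetta's actual config schema:

```python
import os
from typing import Any


def validate_config(values: dict[str, Any]) -> dict[str, Any]:
    """Backfill missing config keys from CLICKZETTA_VOLUME_* env vars.

    Illustrative sketch of a pydantic mode="before" validator body; the
    real validator is a classmethod on ClickZettaVolumeConfig.
    """
    values = dict(values)  # copy so the caller's mapping is not mutated
    for key in ("username", "password", "endpoint"):
        env_name = f"CLICKZETTA_VOLUME_{key.upper()}"
        if not values.get(key) and env_name in os.environ:
            values[key] = os.environ[env_name]
    return values


os.environ["CLICKZETTA_VOLUME_ENDPOINT"] = "https://example.invalid"
merged = validate_config({"username": "demo"})
```

Returning the (possibly modified) dict is what lets pydantic continue with normal field validation on the merged values.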

View File

@@ -65,7 +65,7 @@ class FileMetadata:
         return data
 
     @classmethod
-    def from_dict(cls, data: dict) -> FileMetadata:
+    def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
         """Create instance from dictionary"""
         data = data.copy()
         data["created_at"] = datetime.fromisoformat(data["created_at"])
@@ -459,7 +459,7 @@ class FileLifecycleManager:
             newest_file=None,
         )
 
-    def _create_version_backup(self, filename: str, metadata: dict):
+    def _create_version_backup(self, filename: str, metadata: dict[str, Any]):
         """Create version backup"""
         try:
             # Read current file content
@@ -487,7 +487,7 @@ class FileLifecycleManager:
             logger.warning("Failed to load metadata: %s", e)
             return {}
 
-    def _save_metadata(self, metadata_dict: dict):
+    def _save_metadata(self, metadata_dict: dict[str, Any]):
         """Save metadata file"""
         try:
             metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
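The `from_dict` hunk shows the usual serialization round trip: `datetime` is stored as an ISO-8601 string for `json.dumps` and rebuilt with `datetime.fromisoformat` on the way back in. A self-contained sketch of that pattern, with a deliberately minimal field set (the real `FileMetadata` carries more fields than shown here):

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any


# Minimal stand-in for the FileMetadata to_dict/from_dict round trip;
# field names beyond created_at are illustrative.
@dataclass
class FileMetadata:
    filename: str
    created_at: datetime

    def to_dict(self) -> dict[str, Any]:
        # datetime is not JSON-serializable, so store it as an ISO string
        return {"filename": self.filename, "created_at": self.created_at.isoformat()}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
        data = data.copy()  # avoid mutating the caller's dict
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        return cls(**data)


meta = FileMetadata("report.txt", datetime(2026, 4, 14, 13, 54, 31))
restored = FileMetadata.from_dict(json.loads(json.dumps(meta.to_dict())))
```

The `data.copy()` in `from_dict` matters: without it, parsing `created_at` in place would corrupt the caller's dict on reuse.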
