mirror of https://github.com/langgenius/dify.git
synced 2026-05-10 14:14:17 +08:00

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py

commit 03660c19ef

79  .agents/skills/e2e-cucumber-playwright/SKILL.md  Normal file
@ -0,0 +1,79 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---

# Dify E2E Cucumber + Playwright

Use this skill for Dify's repository-level E2E suite in `e2e/`. Treat [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.

## Scope

- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.

## Read Order

1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
   - target `.feature` files under `e2e/features/`
   - related step files under `e2e/features/step-definitions/`
   - `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
   - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.

## Local Rules

- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions (see the sketch after this list).
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
  - default: authenticated session with shared storage state
  - `@unauthenticated`: clean browser context
  - `@authenticated`: readability/selective-run tag only unless implementation changes
  - `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
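
A minimal sketch of that step style. It assumes `DifyWorld` exposes a `page` handle and that the support path is `../support/world`; the step wording is invented for illustration:

```ts
import { Then, When } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../support/world'

// One user-visible action per step; `async function` keeps `this` bound to DifyWorld.
When('the user opens the {string} app', async function (this: DifyWorld, appName: string) {
  await this.page.getByRole('link', { name: appName }).click()
})

// One assertion per step, using a web-first (auto-retrying) expect.
Then('the app detail page is shown', async function (this: DifyWorld) {
  await expect(this.page.getByRole('heading', { level: 1 })).toBeVisible()
})
```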

## Workflow

1. Rebuild local context.
   - Inspect the target feature area.
   - Reuse an existing step when wording and behavior already match.
   - Add a new step only for a genuinely new user action or assertion.
   - Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
   - Describe user-observable behavior, not DOM mechanics.
   - Keep each scenario focused on one workflow or outcome.
   - Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
   - Keep one step to one user-visible action or one assertion.
   - Prefer Cucumber Expressions such as `{string}` and `{int}`.
   - Scope locators to stable containers when the page has repeated elements.
   - Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
   - Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
   - Use web-first `expect(...)` assertions.
   - Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
   - Run the narrowest tagged scenario or flow that exercises the change.
   - Run `pnpm -C e2e check`.
   - Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
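
Steps 3 and 4 combine like this; the test id and card layout below are assumptions for illustration, not real Dify markup:

```ts
import { Then } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../support/world'

Then('the {string} dataset shows {int} documents', async function (this: DifyWorld, name: string, count: number) {
  // Scope to one card so repeated list items cannot produce an ambiguous match.
  const card = this.page.getByTestId('dataset-card').filter({ hasText: name })
  // Retrying assertion: waits for the count to render, so no waitForTimeout is needed.
  await expect(card.getByText(`${count} documents`)).toBeVisible()
})
```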

## Review Checklist

- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?

Lead findings with correctness, flake risk, and architecture drift.

## References

- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

@ -0,0 +1,4 @@
interface:
  display_name: "E2E Cucumber + Playwright"
  short_description: "Write and review Dify E2E scenarios."
  default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."

@ -0,0 +1,93 @@
# Cucumber Best Practices For Dify E2E

Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.

Official sources:

- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/

## What Matters Most

### 1. Treat scenarios as executable specifications

Cucumber scenarios should describe examples of behavior, not test implementation recipes.

Apply it like this:

- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names; prefer "When the user publishes the app" over "When the user clicks the publish button in the header toolbar"
- keep language concrete enough that the scenario reads like living documentation

### 2. Keep scenarios focused

A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.

In Dify's suite, this means:

- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects

### 3. Reuse steps, but only when behavior really matches

Good reuse reduces duplication. Bad reuse hides meaning.

Prefer reuse when:

- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features

Write a new step when:

- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper

### 4. Prefer Cucumber Expressions

Use Cucumber Expressions for parameters unless regex is clearly necessary.

Common examples:

- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token

Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
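
A sketch of those parameter types in glue code; the step wording and the `page` handle on `DifyWorld` are illustrative assumptions:

```ts
import { When } from '@cucumber/cucumber'
import type { DifyWorld } from '../support/world'

// {string} captures the quoted value; {int} arrives already parsed as a number.
When('the user renames the app to {string}', async function (this: DifyWorld, newName: string) {
  await this.page.getByLabel('App name').fill(newName)
})

When('the user adds {int} variables', async function (this: DifyWorld, count: number) {
  for (let i = 0; i < count; i++) {
    await this.page.getByRole('button', { name: 'Add variable' }).click()
  }
})
```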

### 5. Keep step definitions thin and meaningful

Step definitions are glue between Gherkin and automation, not a second abstraction language.

For Dify:

- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state (see the sketch after this list)
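
For the last two points, a sketch; `createdAppName` is an assumed field that would be declared on the real `DifyWorld`:

```ts
import { Given, Then } from '@cucumber/cucumber'
import { expect } from '@playwright/test'
import type { DifyWorld } from '../support/world'

// Anti-pattern: module-level state survives across scenarios and couples them.
// let createdAppName = ''

Given('an app named {string} exists', async function (this: DifyWorld, name: string) {
  // Scenario state lives on the World, which is rebuilt for every scenario.
  this.createdAppName = name
})

Then('the app list shows the created app', async function (this: DifyWorld) {
  await expect(this.page.getByRole('link', { name: this.createdAppName })).toBeVisible()
})
```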

### 6. Use tags intentionally

Tags should communicate run scope or session semantics, not become ad hoc metadata.

In Dify's current suite:

- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only

If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
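
A tag only has behavior if a hook or runner config reads it. A sketch of how `@unauthenticated` could be wired up; the method name is assumed, and the real `hooks.ts` may differ:

```ts
import { Before } from '@cucumber/cucumber'
import type { DifyWorld } from './world'

// Tagged hook: runs only for scenarios carrying @unauthenticated.
Before({ tags: '@unauthenticated' }, async function (this: DifyWorld) {
  await this.useCleanBrowserContext() // assumed DifyWorld method
})
```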

## Review Questions

- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

@ -0,0 +1,96 @@
# Playwright Best Practices For Dify E2E

Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.

Official sources:

- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts

## What Matters Most

### 1. Keep scenarios isolated

Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.

Apply it like this:

- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes

### 2. Prefer user-facing locators

Playwright recommends built-in locators that reflect what users perceive on the page.

Preferred order in this repository:

1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option

Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.

Also remember:

- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
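
Applying that order and the scoping advice together; the page content here is illustrative, not real Dify markup:

```ts
import { type Page, expect } from '@playwright/test'

async function editFaqDataset(page: Page): Promise<void> {
  // Preferred: role plus accessible name survives DOM refactors.
  await page.getByRole('button', { name: 'Create app' }).click()

  // Repeated content: scope to a stable container before matching by text.
  const row = page.getByRole('listitem').filter({ hasText: 'Customer FAQ' })
  await row.getByRole('button', { name: 'Edit' }).click()

  // Weak semantics but an intentional contract: fall back to a test id.
  await expect(page.getByTestId('dataset-editor')).toBeVisible()
}
```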

### 3. Use web-first assertions

Playwright assertions auto-wait and retry. Prefer them over manual state inspection.

Prefer:

- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`

Avoid:

- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization

If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
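
The contrast in code, with `expect.poll` for the rare genuinely custom condition; the test id and endpoint are assumed examples:

```ts
import { type Page, expect } from '@playwright/test'

async function assertRunFinished(page: Page): Promise<void> {
  // Web-first: retries until the badge reads "Succeeded" or the timeout hits.
  await expect(page.getByTestId('run-status')).toHaveText('Succeeded')

  // Avoid: samples the state once, with no retry, so it races the render.
  // expect(await page.getByTestId('run-status').isVisible()).toBe(true)

  // Genuinely custom condition: expect.poll keeps the retry explicit and local.
  await expect
    .poll(async () => (await page.request.get('/api/run-status')).status(), { timeout: 10_000 })
    .toBe(200)
}
```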

### 4. Let actions wait for actionability

Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.

Good pattern:

- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs

Bad pattern:

- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
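
The same contrast as code; the dialog and fields are assumed for illustration:

```ts
import { type Page, expect } from '@playwright/test'

async function renameApp(page: Page, name: string): Promise<void> {
  // Good: assert the state that matters, then act. fill() and click()
  // already wait for the element to be visible, enabled, and stable.
  await expect(page.getByRole('dialog', { name: 'Rename app' })).toBeVisible()
  await page.getByLabel('App name').fill(name)
  await page.getByRole('button', { name: 'Save' }).click()

  // Bad: an arbitrary sleep that masks readiness problems and slows every run.
  // await page.waitForTimeout(2000)
}
```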

### 5. Match debugging to the current suite

Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:

- full-page screenshots
- page HTML
- console errors
- page errors

Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.

## Review Questions

- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

1  .claude/skills/e2e-cucumber-playwright  Symbolic link
@ -0,0 +1 @@
../../.agents/skills/e2e-cucumber-playwright

7  .github/workflows/docker-build.yml  vendored
@ -6,14 +6,7 @@ on:
      - "main"
    paths:
      - api/Dockerfile
      - web/docker/**
      - web/Dockerfile
      - packages/**
      - package.json
      - pnpm-lock.yaml
      - pnpm-workspace.yaml
      - .npmrc
      - .nvmrc

concurrency:
  group: docker-build-${{ github.head_ref || github.run_id }}

1  .github/workflows/main-ci.yml  vendored
@ -92,6 +92,7 @@ jobs:
            vdb:
              - 'api/core/rag/datasource/**'
              - 'api/tests/integration_tests/vdb/**'
              - 'api/providers/vdb/*/tests/**'
              - '.github/workflows/vdb-tests.yml'
              - '.github/workflows/expose_service_ports.sh'
              - 'docker/.env.example'

2  .github/workflows/vdb-tests-full.yml  vendored
@ -89,7 +89,7 @@ jobs:
          cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

      # - name: Check VDB Ready (TiDB)
      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py

      - name: Test Vector Stores
        run: uv run --project api bash dev/pytest/pytest_vdb.sh

10  .github/workflows/vdb-tests.yml  vendored
@ -81,12 +81,12 @@ jobs:
          cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

      # - name: Check VDB Ready (TiDB)
      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py

      - name: Test Vector Stores
        run: |
          uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
            api/tests/integration_tests/vdb/chroma \
            api/tests/integration_tests/vdb/pgvector \
            api/tests/integration_tests/vdb/qdrant \
            api/tests/integration_tests/vdb/weaviate
            api/providers/vdb/vdb-chroma/tests/integration_tests \
            api/providers/vdb/vdb-pgvector/tests/integration_tests \
            api/providers/vdb/vdb-qdrant/tests/integration_tests \
            api/providers/vdb/vdb-weaviate/tests/integration_tests

@ -69,8 +69,6 @@ ignore = [
    "FURB152", # math-constant
    "UP007", # non-pep604-annotation
    "UP032", # f-string
    "UP045", # non-pep604-annotation-optional
    "B005", # strip-with-multi-characters
    "B006", # mutable-argument-default
    "B007", # unused-loop-control-variable
    "B026", # star-arg-unpacking-after-keyword-arg
@ -84,7 +82,6 @@ ignore = [
    "SIM102", # collapsible-if
    "SIM103", # needless-bool
    "SIM105", # suppressible-exception
    "SIM107", # return-in-try-except-finally
    "SIM108", # if-else-block-instead-of-if-exp
    "SIM113", # enumerate-for-loop
    "SIM117", # multiple-with-statements
@ -93,29 +90,16 @@ ignore = [
]

[lint.per-file-ignores]
"__init__.py" = [
    "F401", # unused-import
    "F811", # redefined-while-unused
]
"configs/*" = [
    "N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
    "N803", # invalid-argument-name
]
"tests/*" = [
    "F811", # redefined-while-unused
    "T201", # allow print in tests
    "S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]

[lint.flake8-tidy-imports]

[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

@ -21,8 +21,9 @@ RUN apt-get update \
    # for building gmpy2
    libmpfr-dev libmpc-dev

# Install Python dependencies
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev

# production stage

@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
        click.echo(click.style("No dataset collection bindings found.", fg="red"))
        return
    import qdrant_client
    from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
    from qdrant_client.http.exceptions import UnexpectedResponse
    from qdrant_client.http.models import PayloadSchemaType

    from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig

    for binding in bindings:
        if dify_config.QDRANT_URL is None:
            raise ValueError("Qdrant URL is required.")

@ -1,4 +1,3 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings

@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
        default="public",
    )

    HOLOGRES_TOKENIZER: TokenizerType = Field(
    HOLOGRES_TOKENIZER: str = Field(
        description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
        default="jieba",
    )

    HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
    HOLOGRES_DISTANCE_METHOD: str = Field(
        description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
        default="Cosine",
    )

    HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
    HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
        description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
        default="rabitq",
    )

@ -1,5 +1,7 @@
"""Configuration for InterSystems IRIS vector database."""

from typing import Any

from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings

@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings):

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Validate IRIS configuration values.

        Args:

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService


class AdvancedPromptTemplateQuery(BaseModel):
@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource):
    @account_initialization_required
    def get(self):
        args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

        return AdvancedPromptTemplateService.get_prompt(args.model_dump())
        prompt_args: AdvancedPromptTemplateArgs = {
            "app_mode": args.app_mode,
            "model_mode": args.model_mode,
            "model_name": args.model_name,
            "has_context": args.has_context,
        }
        return AdvancedPromptTemplateService.get_prompt(prompt_args)

@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:

class MCPServerCreatePayload(BaseModel):
    description: str | None = Field(default=None, description="Server description")
    parameters: dict = Field(..., description="Server parameters configuration")
    parameters: dict[str, Any] = Field(..., description="Server parameters configuration")


class MCPServerUpdatePayload(BaseModel):
    id: str = Field(..., description="Server ID")
    description: str | None = Field(default=None, description="Server description")
    parameters: dict = Field(..., description="Server parameters configuration")
    parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
    status: str | None = Field(default=None, description="Server status")

@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):

        # get paginate workflow app logs
        workflow_app_service = WorkflowAppService()
        with sessionmaker(db.engine).begin() as session:
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
                session=session,
                app_model=app_model,
@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
        args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

        workflow_app_service = WorkflowAppService()
        with sessionmaker(db.engine).begin() as session:
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
                session=session,
                app_model=app_model,

@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService


def _build_backstage_input_url(form_token: str | None) -> str | None:
@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
        Get advanced chat app workflow run list
        """
        args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args: WorkflowRunListArgs = {"limit": args_model.limit}
        if args_model.last_id is not None:
            args["last_id"] = args_model.last_id
        if args_model.status is not None:
            args["status"] = args_model.status

        # Default to DEBUGGING if not specified
        triggered_from = (
@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource):
        Get workflow run list
        """
        args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
        args = args_model.model_dump(exclude_none=True)
        args: WorkflowRunListArgs = {"limit": args_model.limit}
        if args_model.last_id is not None:
            args["last_id"] = args_model.last_id
        if args_model.status is not None:
            args["status"] = args_model.status

        # Default to DEBUGGING for workflow if not specified (backward compatibility)
        triggered_from = (

@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):

        node_id = args.node_id

        with sessionmaker(db.engine).begin() as session:
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            # Get webhook trigger for this app and node
            webhook_trigger = session.scalar(
                select(WorkflowWebhookTrigger)
@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None

        with sessionmaker(db.engine).begin() as session:
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            # Get all triggers for this app using select API
            triggers = (
                session.execute(

@ -1,7 +1,10 @@
import logging

import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
@ -42,12 +45,13 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)


class LoginPayload(LoginPayloadBase):
@ -91,10 +95,12 @@ class LoginApi(Resource):
        normalized_email = request_email.lower()

        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
            raise AccountInFreezeError()

        is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
        if is_login_error_rate_limit:
            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
            raise EmailPasswordLoginLimitError()

        invite_token = args.invite_token
@ -110,14 +116,20 @@ class LoginApi(Resource):
                invitee_email = data.get("email") if data else None
                invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
                if invitee_email_normalized != normalized_email:
                    _log_console_login_failure(
                        email=normalized_email,
                        reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
                    )
                    raise InvalidEmailError()
            account = _authenticate_account_with_case_fallback(
                request_email, normalized_email, args.password, invite_token
            )
        except services.errors.account.AccountLoginError:
            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
            raise AccountBannedError()
        except services.errors.account.AccountPasswordError as exc:
            AccountService.add_login_error_rate_limit(normalized_email)
            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
            raise AuthenticationFailedError() from exc
        # SELF_HOSTED only have one workspace
        tenants = TenantService.get_join_tenants(account)
@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):

        token_data = AccountService.get_email_code_login_data(args.token)
        if token_data is None:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
            raise InvalidTokenError()

        token_email = token_data.get("email")
        normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
        if normalized_token_email != user_email:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
            raise InvalidEmailError()

        if token_data["code"] != args.code:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
            raise EmailCodeError()

        AccountService.revoke_email_code_login_token(args.token)
        try:
            account = _get_account_with_case_fallback(original_email)
        except Unauthorized as exc:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
            raise AccountBannedError() from exc
        except AccountRegisterError:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
            raise AccountInFreezeError()
        if account:
            tenants = TenantService.get_join_tenants(account)
@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
        except WorkSpaceNotAllowedCreateError:
            raise NotAllowedCreateWorkspace()
        except AccountRegisterError:
            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
            raise AccountInFreezeError()
        except WorkspacesLimitExceededError:
            raise WorkspacesLimitExceeded()
@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
    if original_email == normalized_email:
        raise
    return AccountService.authenticate(normalized_email, password, invite_token)


def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
    logger.warning(
        "Console login failed: email=%s reason=%s ip_address=%s",
        email,
        reason,
        extract_remote_ip(request),
    )

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import TypedDict

from flask import request
@ -13,6 +14,14 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"


class NotificationLangContent(TypedDict, total=False):
    lang: str
    title: str
    subtitle: str
    body: str
    titlePicUrl: str


class NotificationItemDict(TypedDict):
    notification_id: str | None
    frequency: str | None
@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict):
    notifications: list[NotificationItemDict]


def _pick_lang_content(contents: dict, lang: str) -> dict:
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
    """Return the single LangContent for *lang*, falling back to English."""
    return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
    return (
        contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
    )


class DismissNotificationPayload(BaseModel):
@ -71,7 +82,7 @@ class NotificationApi(Resource):

        notifications: list[NotificationItemDict] = []
        for notification in result.get("notifications") or []:
            contents: dict = notification.get("contents") or {}
            contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
            lang_content = _pick_lang_content(contents, lang)
            item: NotificationItemDict = {
                "notification_id": notification.get("notificationId"),

@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Literal
from typing import Any, Literal

import pytz
from flask import request
@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)


def _serialize_account(account) -> dict:
def _serialize_account(account) -> dict[str, Any]:
    return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService
from services.operation_service import OperationService, UtmInfo

from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout

@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
        utm_info = request.cookies.get("utm_info")

        if utm_info:
            utm_info_dict: dict = json.loads(utm_info)
            utm_info_dict: UtmInfo = json.loads(utm_info)
            OperationService.record_utm(current_tenant_id, utm_info_dict)

        return view(*args, **kwargs)

@ -2,7 +2,7 @@ from typing import Any, Union

from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -158,14 +158,20 @@ class MCPAppApi(Resource):
        except ValidationError as e:
            raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")

    def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
    def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
        """Convert raw user input form to VariableEntity objects"""
        return [self._create_variable_entity(item) for item in raw_form]

    def _create_variable_entity(self, item: dict) -> VariableEntity:
    def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
        """Create a single VariableEntity from raw form item"""
        variable_type = item.get("type", "") or list(item.keys())[0]
        variable = item[variable_type]
        variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
        try:
            variable_type = VariableEntityType(variable_type_raw)
        except ValueError as e:
            raise MCPRequestError(
                mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
            ) from e
        variable = item[variable_type_raw]

        return VariableEntity(
            type=variable_type,
@ -178,7 +184,7 @@ class MCPAppApi(Resource):
            json_schema=variable.get("json_schema"),
        )

    def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
    def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
        """Parse and validate MCP request"""
        try:
            return mcp_types.ClientRequest.model_validate(args)

@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService


def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
    """Marshal a single segment and enrich it with summary content."""
    segment_dict = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
    segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
    summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
    segment_dict["summary"] = summary.summary_content if summary else None
    return segment_dict


def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
    """Marshal multiple segments and enrich them with summary content (batch query)."""
    segment_ids = [segment.id for segment in segments]
    summaries: dict = {}
    summaries: dict[str, str | None] = {}
    if segment_ids:
        summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
        summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}

    result = []
    result: list[dict[str, Any]] = []
    for segment in segments:
        segment_dict = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
        segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
        segment_dict["summary"] = summaries.get(segment.id)
        result.append(segment_dict)
    return result

@ -5,6 +5,7 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
@ -58,10 +59,19 @@ def _to_timestamp(value: datetime) -> int:
    return int(value.timestamp())


class FormDefinitionPayload(TypedDict):
    form_content: Any
    inputs: Any
    resolved_default_values: dict[str, str]
    user_actions: Any
    expiration_time: int
    site: NotRequired[dict]


def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
    """Return the form payload (optionally with site) as a JSON response."""
    definition_payload = form.get_definition().model_dump()
    payload = {
    payload: FormDefinitionPayload = {
        "form_content": definition_payload["rendered_content"],
        "inputs": definition_payload["inputs"],
        "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),

@ -1,7 +1,10 @@
import logging

from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized

import services
from configs import dify_config
@ -20,7 +23,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr
from libs.helper import EmailStr, extract_remote_ip
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@ -29,9 +32,11 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService

logger = logging.getLogger(__name__)


class LoginPayload(LoginPayloadBase):
    @field_validator("password")
@ -76,14 +81,18 @@ class LoginApi(Resource):
    def post(self):
        """Authenticate user and login."""
        payload = LoginPayload.model_validate(web_ns.payload or {})
        normalized_email = payload.email.lower()

        try:
            account = WebAppAuthService.authenticate(payload.email, payload.password)
        except services.errors.account.AccountLoginError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
            raise AccountBannedError()
        except services.errors.account.AccountPasswordError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
            raise AuthenticationFailedError()
        except services.errors.account.AccountNotFoundError:
            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
            raise AuthenticationFailedError()

        token = WebAppAuthService.login(account=account)
@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):

        token_data = WebAppAuthService.get_email_code_login_data(payload.token)
        if token_data is None:
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
            raise InvalidTokenError()

        token_email = token_data.get("email")
        if not isinstance(token_email, str):
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
            raise InvalidEmailError()
        normalized_token_email = token_email.lower()
        if normalized_token_email != user_email:
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
            raise InvalidEmailError()

        if token_data["code"] != payload.code:
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
            raise EmailCodeError()

        WebAppAuthService.revoke_email_code_login_token(payload.token)
        account = WebAppAuthService.get_user_through_email(token_email)
        try:
            account = WebAppAuthService.get_user_through_email(token_email)
        except Unauthorized as exc:
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
            raise AccountBannedError() from exc
        if not account:
            _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
            raise AuthenticationFailedError()

        token = WebAppAuthService.login(account=account)
@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
        response = make_response({"result": "success", "data": {"access_token": token}})
        # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
        return response


def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
    logger.warning(
        "Web login failed: email=%s reason=%s ip_address=%s",
        email,
        reason,
        extract_remote_ip(request),
    )

@ -1,5 +1,6 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any

from flask import make_response, request
from flask_restx import Resource
@ -103,21 +104,23 @@ class PassportResource(Resource):
        return response


def decode_enterprise_webapp_user_id(jwt_token: str | None):
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
    """
    Decode the enterprise user session from the Authorization header.
    """
    if not jwt_token:
        return None

    decoded = PassportService().verify(jwt_token)
    decoded: dict[str, Any] = PassportService().verify(jwt_token)
    source = decoded.get("token_source")
    if not source or source != "webapp_login_token":
        raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
    return decoded


def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
def exchange_token_for_existing_web_user(
    app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
    """
    Exchange a token for an existing web user session.
    """

@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast

from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@ -113,12 +113,12 @@ class AppSiteInfo:
    }


def serialize_site(site: Site) -> dict:
def serialize_site(site: Site) -> dict[str, Any]:
    """Serialize Site model using the same schema as AppSiteApi."""
    return cast(dict, marshal(site, AppSiteApi.site_fields))
    return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))


def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
    can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
    app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
    return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
    return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))

@ -138,7 +138,9 @@ class DatasetConfigManager:
        )

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(
        cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
    ) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for dataset feature

@ -172,7 +174,7 @@ class DatasetConfigManager:
        return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

    @classmethod
    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
        """
        Extract dataset config for legacy compatibility

@ -108,7 +108,7 @@ class ModelConfigManager:
        return dict(config), ["model"]

    @classmethod
    def validate_model_completion_params(cls, cp: dict):
    def validate_model_completion_params(cls, cp: dict[str, Any]):
        # model.completion_params
        if not isinstance(cp, dict):
            raise ValueError("model.completion_params must be of object type")

@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
        )

    @classmethod
    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate pre_prompt and set defaults for prompt feature
        depending on the config['model']
@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
        return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]

    @classmethod
    def validate_post_prompt_and_set_defaults(cls, config: dict):
    def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
        """
        Validate post_prompt and set defaults for prompt feature

@ -1,5 +1,5 @@
import re
from typing import cast
from typing import Any, cast

from graphon.variables.input_entities import VariableEntity, VariableEntityType

@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
        return variable_entities, external_data_variables

    @classmethod
    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for user input form

@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
        return config, related_config_keys

    @classmethod
    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for user input form

@ -164,7 +164,9 @@ class BasicVariablesConfigManager:
        return config, ["user_input_form"]

    @classmethod
    def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
    def validate_external_data_tools_and_set_defaults(
        cls, tenant_id: str, config: dict[str, Any]
    ) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for external data fetch feature

@ -30,7 +30,7 @@ class FileUploadConfigManager:
        return FileUploadConfig.model_validate(file_upload_dict)

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for file upload feature

@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, ValidationError


@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):

class MoreLikeThisConfigManager:
    @classmethod
    def convert(cls, config: dict) -> bool:
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config

@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
        return AppConfigModel.model_validate(validated_config).more_like_this.enabled

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        try:
            return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
        except ValidationError:

@ -1,6 +1,9 @@
from typing import Any


class OpeningStatementConfigManager:
    @classmethod
    def convert(cls, config: dict) -> tuple[str, list]:
    def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
        """
        Convert model config to model config

@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
        return opening_statement, suggested_questions_list

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for opening statement feature

@ -1,6 +1,9 @@
from typing import Any


class RetrievalResourceConfigManager:
    @classmethod
    def convert(cls, config: dict) -> bool:
    def convert(cls, config: dict[str, Any]) -> bool:
        show_retrieve_source = False
        retriever_resource_dict = config.get("retriever_resource")
        if retriever_resource_dict:
@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
        return show_retrieve_source

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for retriever resource feature

@ -1,6 +1,9 @@
from typing import Any


class SpeechToTextConfigManager:
    @classmethod
    def convert(cls, config: dict) -> bool:
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config

@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
        return speech_to_text

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for speech to text feature

@ -1,6 +1,9 @@
from typing import Any


class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
    def convert(cls, config: dict) -> bool:
    def convert(cls, config: dict[str, Any]) -> bool:
        """
        Convert model config to model config

@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        return suggested_questions_after_answer

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for suggested questions feature

@ -1,9 +1,11 @@
from typing import Any

from core.app.app_config.entities import TextToSpeechEntity


class TextToSpeechConfigManager:
    @classmethod
    def convert(cls, config: dict):
    def convert(cls, config: dict[str, Any]):
        """
        Convert model config to model config

@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
        return text_to_speech

    @classmethod
    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for text to speech feature

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import cast
from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = WorkflowAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        return dict(blocking_response.model_dump())

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager):
         return pipeline_config

     @classmethod
-    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
+    def config_validate(
+        cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
+    ) -> dict[str, Any]:
         """
         Validate for pipeline config
@@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
         user_id: str,
         all_files: list,
         datasource_info: Mapping[str, Any],
-        next_page_parameters: dict | None = None,
+        next_page_parameters: dict[str, Any] | None = None,
     ):
         """
         Get files in a folder.
@@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
         node_type: str
         title: str
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
         inputs_truncated: bool = False
@@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
         title: str
         index: int
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)

     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
@@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: dict | None = None
+        extras: dict[str, Any] | None = None
         inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
@@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
         node_type: str
         title: str
         created_at: int
-        extras: dict = Field(default_factory=dict)
+        extras: dict[str, Any] = Field(default_factory=dict)
         metadata: Mapping = {}
         inputs: Mapping = {}
         inputs_truncated: bool = False
@@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
         outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: dict | None = None
+        extras: dict[str, Any] | None = None
         inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
    description: I18nObject
    parameters: list[DatasourceParameter] | None = None
    labels: list[str] = Field(default_factory=list)
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None


 ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
    icon: str | dict
    label: I18nObjectDict
    type: str
-    team_credentials: dict | None
+    team_credentials: dict[str, Any] | None
    is_team_authorization: bool
    allow_delete: bool
    datasources: list[Any]
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
    icon: str | dict
    label: I18nObject  # label
    type: str
-    masked_credentials: dict | None = None
-    original_credentials: dict | None = None
+    masked_credentials: dict[str, Any] | None = None
+    original_credentials: dict[str, Any] | None = None
    is_team_authorization: bool = False
    allow_delete: bool = True
    plugin_id: str | None = Field(default="", description="The plugin id of the datasource")

@@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
    identity: DatasourceIdentity
    parameters: list[DatasourceParameter] = Field(default_factory=list)
    description: I18nObject = Field(..., description="The label of the datasource")
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None

    @field_validator("parameters", mode="before")
    @classmethod
@@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):

    time_cost: float = Field(..., description="The time cost of the tool invoke")
    error: str | None = None
-    tool_config: dict | None = None
+    tool_config: dict[str, Any] | None = None

    @classmethod
    def empty(cls) -> DatasourceInvokeMeta:
@@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):

    page_id: str = Field(..., description="The page id")
    page_name: str = Field(..., description="The page title")
-    page_icon: dict | None = Field(None, description="The page icon")
+    page_icon: dict[str, Any] | None = Field(None, description="The page icon")
    type: str = Field(..., description="The type of the page")
    last_edited_time: str = Field(..., description="The last edited time")
    parent_id: str | None = Field(None, description="The parent page id")
@@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
    Get website crawl request
    """

-    crawl_parameters: dict = Field(..., description="The crawl parameters")
+    crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")


 class WebSiteInfoDetail(BaseModel):
@@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
    bucket: str | None = Field(None, description="The file bucket")
    files: list[OnlineDriveFile] = Field(..., description="The file list")
    is_truncated: bool = Field(False, description="Whether the result is truncated")
-    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")


 class OnlineDriveBrowseFilesRequest(BaseModel):
@@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
    bucket: str | None = Field(None, description="The file bucket")
    prefix: str = Field(..., description="The parent folder ID")
    max_keys: int = Field(20, description="Page size for pagination")
-    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")


 class OnlineDriveBrowseFilesResponse(BaseModel):
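The same parameterization carries over to the pydantic fields above; `Field` defaults and descriptions are untouched. A minimal sketch with an invented model (not a Dify entity), assuming pydantic v2 as used elsewhere in this codebase:

```python
# Illustrative model showing the typed-dict field style from the diff above.
from typing import Any

from pydantic import BaseModel, Field


class ExamplePage(BaseModel):
    page_id: str = Field(..., description="The page id")
    page_icon: dict[str, Any] | None = Field(None, description="The page icon")
    next_page_parameters: dict[str, Any] | None = Field(
        None, description="Parameters for fetching the next page"
    )


page = ExamplePage(page_id="p1", page_icon={"type": "emoji", "emoji": "📄"})
```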
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field, field_validator


@@ -37,7 +39,7 @@ class PipelineDocument(BaseModel):
    id: str
    position: int
    data_source_type: str
-    data_source_info: dict | None = None
+    data_source_info: dict[str, Any] | None = None
    name: str
    indexing_status: str
    error: str | None = None
@@ -6,6 +6,7 @@ import re
 from collections import defaultdict
 from collections.abc import Iterator, Sequence
 from json import JSONDecodeError
+from typing import Any

 from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
 from graphon.model_runtime.entities.provider_entities import (
@@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel):
            return ModelProviderFactory(model_runtime=self._bound_model_runtime)
        return create_plugin_model_provider_factory(tenant_id=self.tenant_id)

-    def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
+    def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
        """
        Get current credentials.

@@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel):

            return session.execute(stmt).scalar_one_or_none()

-    def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
+    def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
        """
        Get a specific provider credential by ID.
        :param credential_id: Credential ID
@@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel):
                stmt = stmt.where(ProviderCredential.id != exclude_id)
            return session.execute(stmt).scalar_one_or_none() is not None

-    def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
+    def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
        """
        Get provider credentials.

@@ -317,7 +318,9 @@ class ProviderConfiguration(BaseModel):
            else [],
        )

-    def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
+    def validate_provider_credentials(
+        self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
+    ):
        """
        Validate custom credentials.
        :param credentials: provider credentials
@@ -447,7 +450,7 @@ class ProviderConfiguration(BaseModel):
            provider_names.append(model_provider_id.provider_name)
        return provider_names

-    def create_provider_credential(self, credentials: dict, credential_name: str | None):
+    def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
        """
        Add custom provider credentials.
        :param credentials: provider credentials
@@ -515,7 +518,7 @@ class ProviderConfiguration(BaseModel):

    def update_provider_credential(
        self,
-        credentials: dict,
+        credentials: dict[str, Any],
        credential_id: str,
        credential_name: str | None,
    ):
@@ -760,7 +763,7 @@ class ProviderConfiguration(BaseModel):

    def _get_specific_custom_model_credential(
        self, model_type: ModelType, model: str, credential_id: str
-    ) -> dict | None:
+    ) -> dict[str, Any] | None:
        """
        Get a specific provider credential by ID.
        :param credential_id: Credential ID
@@ -832,7 +835,9 @@ class ProviderConfiguration(BaseModel):
                stmt = stmt.where(ProviderModelCredential.id != exclude_id)
            return session.execute(stmt).scalar_one_or_none() is not None

-    def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
+    def get_custom_model_credential(
+        self, model_type: ModelType, model: str, credential_id: str | None
+    ) -> dict[str, Any] | None:
        """
        Get custom model credentials.

@@ -872,7 +877,7 @@ class ProviderConfiguration(BaseModel):
        self,
        model_type: ModelType,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        credential_id: str = "",
        session: Session | None = None,
    ):
@@ -939,7 +944,7 @@ class ProviderConfiguration(BaseModel):
            return _validate(new_session)

    def create_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
+        self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
    ) -> None:
        """
        Create a custom model credential.
@@ -1002,7 +1007,12 @@ class ProviderConfiguration(BaseModel):
            raise

    def update_custom_model_credential(
-        self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
+        self,
+        model_type: ModelType,
+        model: str,
+        credentials: dict[str, Any],
+        credential_name: str | None,
+        credential_id: str,
    ) -> None:
        """
        Update a custom model credential.
@@ -1412,7 +1422,9 @@ class ProviderConfiguration(BaseModel):
        # Get model instance of LLM
        return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)

-    def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
+    def get_model_schema(
+        self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
+    ) -> AIModelEntity | None:
        """
        Get model schema
        """
@@ -1471,7 +1483,7 @@ class ProviderConfiguration(BaseModel):

        return secret_input_form_variables

-    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
+    def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
        """
        Obfuscated credentials.
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from enum import StrEnum, auto
-from typing import Union
+from typing import Any, Union

 from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, ConfigDict, Field
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
    enabled: bool
    current_quota_type: ProviderQuotaType | None = None
    quota_configurations: list[QuotaConfiguration] = []
-    credentials: dict | None = None
+    credentials: dict[str, Any] | None = None


 class CustomProviderConfiguration(BaseModel):
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
    Model class for provider custom configuration.
    """

-    credentials: dict
+    credentials: dict[str, Any]
    current_credential_id: str | None = None
    current_credential_name: str | None = None
    available_credentials: list[CredentialConfiguration] = []
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):

    model: str
    model_type: ModelType
-    credentials: dict | None
+    credentials: dict[str, Any] | None
    current_credential_id: str | None = None
    current_credential_name: str | None = None
    available_model_credentials: list[CredentialConfiguration] = []
@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):

    id: str
    name: str
-    credentials: dict
+    credentials: dict[str, Any]
    credential_source_type: str | None = None
    credential_id: str | None = None
@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension


 class ExternalDataToolFactory:
-    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
+    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
        extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        self.__extension_instance = extension_class(
            tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
        )

    @classmethod
-    def validate_config(cls, name: str, tenant_id: str, config: dict):
+    def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
        """
        Validate the incoming form config data.
@@ -1,6 +1,7 @@
 import json
 from enum import StrEnum
 from json import JSONDecodeError
+from typing import Any

 from extensions.ext_redis import redis_client

@@ -15,7 +16,7 @@ class ProviderCredentialsCache:
    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
        self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
        """
        Get cached model provider credentials.

@@ -33,7 +34,7 @@ class ProviderCredentialsCache:
        else:
            return None

-    def set(self, credentials: dict):
+    def set(self, credentials: dict[str, Any]):
        """
        Cache model provider credentials.
@@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
        """Generate cache key based on subclass implementation"""
        pass

-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
        """Get cached provider credentials"""
        cached_credentials = redis_client.get(self.cache_key)
        if cached_credentials:
@@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
 class NoOpProviderCredentialCache:
    """No-op provider credential cache"""

-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
        """Get cached provider credentials"""
        return None
@@ -1,6 +1,7 @@
 import json
 from enum import StrEnum
 from json import JSONDecodeError
+from typing import Any

 from extensions.ext_redis import redis_client

@@ -18,7 +19,7 @@ class ToolParameterCache:
            f":identity_id:{identity_id}"
        )

-    def get(self) -> dict | None:
+    def get(self) -> dict[str, Any] | None:
        """
        Get cached model provider credentials.

@@ -36,7 +37,7 @@ class ToolParameterCache:
        else:
            return None

-    def set(self, parameters: dict):
+    def set(self, parameters: dict[str, Any]):
        """Cache model provider credentials."""
        redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
@@ -735,7 +735,9 @@ class IndexingRunner:

    @staticmethod
    def _update_document_index_status(
-        document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
+        document_id: str,
+        after_indexing_status: IndexingStatus,
+        extra_update_params: dict[Any, Any] | None = None,
    ):
        """
        Update the document indexing status.
@@ -762,7 +764,7 @@ class IndexingRunner:
        db.session.commit()

    @staticmethod
-    def _update_segments_by_document(dataset_document_id: str, update_params: dict):
+    def _update_segments_by_document(dataset_document_id: str, update_params: dict[Any, Any]):
        """
        Update the document segment by document id.
        """
@@ -200,7 +200,7 @@ def _handle_native_json_schema(
    provider: str,
    model_schema: AIModelEntity,
    structured_output_schema: Mapping,
-    model_parameters: dict,
+    model_parameters: dict[str, Any],
    rules: list[ParameterRule],
 ):
    """
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
    return model_parameters


-def _set_response_format(model_parameters: dict, rules: list):
+def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]):
    """
    Set the appropriate response format parameter based on model rules.

@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
    return {"schema": processed_schema, "name": "llm_response"}


-def remove_additional_properties(schema: dict):
+def remove_additional_properties(schema: dict[str, Any]):
    """
    Remove additionalProperties fields from JSON schema.
    Used for models like Gemini that don't support this property.
@@ -77,7 +77,7 @@ class ModelInstance:

    @staticmethod
    def _get_load_balancing_manager(
-        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
+        configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
    ) -> Optional["LBModelManager"]:
        """
        Get load balancing model credentials
@@ -115,7 +115,7 @@ class ModelInstance:
    def invoke_llm(
        self,
        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: Literal[True] = True,
@@ -126,7 +126,7 @@ class ModelInstance:
    def invoke_llm(
        self,
        prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: Literal[False] = False,
@@ -137,7 +137,7 @@ class ModelInstance:
    def invoke_llm(
        self,
        prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
@@ -147,7 +147,7 @@ class ModelInstance:
    def invoke_llm(
        self,
        prompt_messages: Sequence[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
        tools: Sequence[PromptMessageTool] | None = None,
        stop: Sequence[str] | None = None,
        stream: bool = True,
@@ -528,7 +528,7 @@ class LBModelManager:
        model_type: ModelType,
        model: str,
        load_balancing_configs: list[ModelLoadBalancingConfiguration],
-        managed_credentials: dict | None = None,
+        managed_credentials: dict[str, Any] | None = None,
    ):
        """
        Load balancing model manager
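The `invoke_llm` signatures above are `typing.overload` variants keyed on the `stream` flag: `Literal[True]` and `Literal[False]` let the checker pick the streaming or blocking return type statically. A standalone sketch of the pattern (names and return types are illustrative, not the Dify signatures):

```python
# Sketch of the Literal-keyed overload pattern used by invoke_llm above.
from collections.abc import Generator
from typing import Any, Literal, overload


@overload
def invoke(params: dict[str, Any] | None = None, *, stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def invoke(params: dict[str, Any] | None = None, *, stream: Literal[False]) -> str: ...
def invoke(params: dict[str, Any] | None = None, *, stream: bool = True) -> Generator[str, None, None] | str:
    # The real implementation accepts the broadest signature; callers get the
    # narrowed return type from whichever overload their `stream` value matches.
    if stream:
        return (chunk for chunk in ("a", "b"))
    return "ab"
```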
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field
 from sqlalchemy import select

@@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension

 class ModerationInputParams(BaseModel):
    app_id: str = ""
-    inputs: dict = Field(default_factory=dict)
+    inputs: dict[str, Any] = Field(default_factory=dict)
    query: str = ""


@@ -23,7 +25,7 @@ class ApiModeration(Moderation):
    name: str = "api"

    @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
        """
        Validate the incoming form config data.

@@ -41,7 +43,7 @@ class ApiModeration(Moderation):
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
@@ -73,7 +75,7 @@ class ApiModeration(Moderation):
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

-    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
+    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
        if self.config is None:
            raise ValueError("The config is not set.")
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import StrEnum, auto
+from typing import Any

 from pydantic import BaseModel, Field

@@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
-    inputs: dict = Field(default_factory=dict)
+    inputs: dict[str, Any] = Field(default_factory=dict)
    query: str = ""


@@ -33,13 +34,13 @@ class Moderation(Extensible, ABC):

    module: ExtensionModule = ExtensionModule.MODERATION

-    def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
+    def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
-    def validate_config(cls, tenant_id: str, config: dict) -> None:
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
        """
        Validate the incoming form config data.

@@ -50,7 +51,7 @@ class Moderation(Extensible, ABC):
        raise NotImplementedError

    @abstractmethod
-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user inputs, this method will be called to perform sensitive content review
@@ -75,7 +76,7 @@ class Moderation(Extensible, ABC):
        raise NotImplementedError

    @classmethod
-    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
+    def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
        # inputs_config
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.extension.extensible import ExtensionModule
 from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
 from extensions.ext_code_based_extension import code_based_extension
@@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension
 class ModerationFactory:
    __extension_instance: Moderation

-    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
+    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        self.__extension_instance = extension_class(app_id, tenant_id, config)

    @classmethod
-    def validate_config(cls, name: str, tenant_id: str, config: dict):
+    def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
        """
        Validate the incoming form config data.

@@ -24,7 +26,7 @@ class ModerationFactory:
        # FIXME: mypy error, try to fix it instead of using type: ignore
        extension_class.validate_config(tenant_id, config)  # type: ignore

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user inputs, this method will be called to perform sensitive content review
@@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
    name: str = "keywords"

    @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
        """
        Validate the incoming form config data.

@@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
        if len(keywords_row_len) > 100:
            raise ValueError("the number of rows for the keywords must be less than 100")

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
@@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

-    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
+    def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:
@@ -1,3 +1,5 @@
+from typing import Any
+
 from graphon.model_runtime.entities.model_entities import ModelType

 from core.model_manager import ModelManager
@@ -8,7 +10,7 @@ class OpenAIModeration(Moderation):
    name: str = "openai_moderation"

    @classmethod
-    def validate_config(cls, tenant_id: str, config: dict):
+    def validate_config(cls, tenant_id: str, config: dict[str, Any]):
        """
        Validate the incoming form config data.

@@ -18,7 +20,7 @@ class OpenAIModeration(Moderation):
        """
        cls._validate_inputs_and_outputs_config(config, True)

-    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
        flagged = False
        preset_response = ""
        if self.config is None:
@@ -49,7 +51,7 @@ class OpenAIModeration(Moderation):
            flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
        )

-    def _is_violated(self, inputs: dict):
+    def _is_violated(self, inputs: dict[str, Any]):
        text = "\n".join(str(inputs.values()))
        model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
        model_instance = model_manager.get_model_instance(
@@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
            raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")

-    def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+    def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
        """Construct LLM attributes with passed prompts for Arize/Phoenix."""
        attributes: dict[str, str] = {}

@@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
                set_attribute(path, value)

-        def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
+        def set_tool_call_attributes(
+            message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
+        ) -> None:
            """Extract and assign tool call details safely."""
            if not tool_call:
                return
@@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

+    @staticmethod
+    def _get_completion_start_time(
+        start_time: datetime | None, time_to_first_token: float | int | None
+    ) -> datetime | None:
+        """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
+        if start_time is None or time_to_first_token is None:
+            return None
+
+        try:
+            ttft_seconds = float(time_to_first_token)
+        except (TypeError, ValueError):
+            return None
+
+        if ttft_seconds < 0:
+            return None
+
+        return start_time + timedelta(seconds=ttft_seconds)
+
    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
@@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
        total_token = metadata.get("total_tokens", 0)
        prompt_tokens = 0
        completion_tokens = 0
+        completion_start_time = None
        try:
-            usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+            usage_data = process_data.get("usage")
+            if not isinstance(usage_data, dict):
+                usage_data = outputs.get("usage")
+            if not isinstance(usage_data, dict):
+                usage_data = {}
            prompt_tokens = usage_data.get("prompt_tokens", 0)
            completion_tokens = usage_data.get("completion_tokens", 0)
+            completion_start_time = self._get_completion_start_time(
+                created_at, usage_data.get("time_to_first_token")
+            )
        except Exception:
            logger.error("Failed to extract usage", exc_info=True)

@@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
            trace_id=trace_id,
            model=process_data.get("model_name"),
            start_time=created_at,
+            completion_start_time=completion_start_time,
            end_time=finished_at,
            input=inputs,
            output=outputs,
@@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
            unit=UnitEnum.TOKENS,
            totalCost=message_data.total_price,
        )
+        completion_start_time = self._get_completion_start_time(
+            trace_info.start_time,
+            trace_info.gen_ai_server_time_to_first_token,
+        )

        langfuse_generation_data = LangfuseGeneration(
            name="llm",
            trace_id=trace_id,
            start_time=trace_info.start_time,
+            completion_start_time=completion_start_time,
            end_time=trace_info.end_time,
            model=message_data.model_id,
            input=trace_info.inputs,
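The new `_get_completion_start_time` helper turns a relative time-to-first-token (seconds) into the absolute timestamp Langfuse expects for `completion_start_time`. A worked example with made-up values:

```python
# Worked example of the TTFT conversion added above; values are illustrative.
from datetime import datetime, timedelta

start_time = datetime(2025, 1, 1, 12, 0, 0)
ttft_seconds = 0.85  # e.g. usage_data.get("time_to_first_token")

completion_start_time = start_time + timedelta(seconds=ttft_seconds)
print(completion_start_time)  # 2025-01-01 12:00:00.850000
```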
@@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):

        return inputs, attributes

-    def _parse_knowledge_retrieval_outputs(self, outputs: dict):
+    def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
        """Parse KR outputs and attributes from KR workflow node"""
        retrieved = outputs.get("result", [])

@@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
            end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
        )

-    def _get_message_user_id(self, metadata: dict) -> str | None:
+    def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
        if (end_user_id := metadata.get("from_end_user_id")) and (
            end_user_data := db.session.get(EndUser, end_user_id)
        ):
@@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
        }
        return node_type_mapping.get(node_type, "CHAIN")  # type: ignore[arg-type,call-overload]

-    def _set_trace_metadata(self, span: Span, metadata: dict):
+    def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
        token = None
        try:
            # NB: Set span in context such that we can use update_current_trace() API
@@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
            return messages
        return prompts  # Fallback to original format

-    def _parse_single_message(self, item: dict):
+    def _parse_single_message(self, item: dict[str, Any]):
        """Postprocess single message format to be standard chat message"""
        role = item.get("role", "user")
        msg = {"role": role, "content": item.get("text", "")}
@@ -3,7 +3,7 @@ import logging
 import os
 import uuid
 from datetime import datetime, timedelta
-from typing import cast
+from typing import Any, cast

 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from opik import Opik, Trace
@@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):

        self.add_span(span_data)

-    def add_trace(self, opik_trace_data: dict) -> Trace:
+    def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
        try:
            trace = self.opik_client.trace(**opik_trace_data)
            logger.debug("Opik Trace created successfully")
@@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
        except Exception as e:
            raise ValueError(f"Opik Failed to create trace: {str(e)}")

-    def add_span(self, opik_span_data: dict):
+    def add_span(self, opik_span_data: dict[str, Any]):
        try:
            self.opik_client.span(**opik_span_data)
            logger.debug("Opik Span created successfully")
@@ -324,7 +324,7 @@ class OpsTraceManager:

    @classmethod
    def encrypt_tracing_config(
-        cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
+        cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
    ):
        """
        Encrypt tracing config.
@@ -363,7 +363,7 @@ class OpsTraceManager:
        return encrypted_config.model_dump()

    @classmethod
-    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
+    def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
        """
        Decrypt tracing config
        :param tenant_id: tenant id
@@ -408,7 +408,7 @@ class OpsTraceManager:
        return dict(decrypted_config)

    @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
        """
        Decrypt tracing config
        :param tracing_provider: tracing provider
@@ -581,7 +581,7 @@ class OpsTraceManager:
        return app_trace_config

    @staticmethod
-    def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
+    def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
        """
        Check trace config is effective
        :param tracing_config: tracing config
@@ -596,7 +596,7 @@ class OpsTraceManager:
        return trace_instance(config).api_check()

    @staticmethod
-    def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
+    def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
        """
        get trace config is project key
        :param tracing_config: tracing config
@@ -611,7 +611,7 @@ class OpsTraceManager:
        return trace_instance(config).get_project_key()

    @staticmethod
-    def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
+    def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
        """
        get trace config is project key
        :param tracing_config: tracing config
@@ -1322,8 +1322,8 @@ class TraceTask:
            error=error,
        )

-    def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
-        node_data: dict = kwargs.get("node_execution_data", {})
+    def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
+        node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
        if not node_data:
            return {}

@@ -1431,7 +1431,7 @@ class TraceTask:
            return node_trace
        return DraftNodeExecutionTrace(**node_trace.model_dump())

-    def _extract_streaming_metrics(self, message_data) -> dict:
+    def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
        if not message_data.message_metadata:
            return {}
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any

 from pydantic import BaseModel, Field, model_validator

@@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity):
    entity of an endpoint
    """

-    settings: dict
+    settings: dict[str, Any]
    tenant_id: str
    plugin_id: str
    expired_at: datetime
@@ -1,3 +1,5 @@
+from typing import Any
+
 from graphon.model_runtime.entities.provider_entities import ProviderEntity
 from pydantic import BaseModel, Field, computed_field, model_validator

@@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel):

    @model_validator(mode="before")
    @classmethod
-    def transform_declaration(cls, data: dict):
+    def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
        if "endpoint" in data and not data["endpoint"]:
            del data["endpoint"]
        if "model" in data and not data["model"]:
@@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):

    @model_validator(mode="before")
    @classmethod
-    def validate_category(cls, values: dict):
+    def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
        # auto detect category
        if values.get("tool"):
            values["category"] = PluginCategory.Tool
@@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
    """

    result: bool
-    credentials: dict | None = None
+    credentials: dict[str, Any] | None = None


 class PluginModelSchemaEntity(BaseModel):

@@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
    tool_type: Literal["builtin", "workflow", "api", "mcp"]
    provider: str
    tool: str
-    tool_parameters: dict
+    tool_parameters: dict[str, Any]
    credential_id: str | None = None


@@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
    opt: Literal["encrypt", "decrypt", "clear"]
    namespace: Literal["endpoint"]
    identity: str
-    data: dict = Field(default_factory=dict)
+    data: dict[str, Any] = Field(default_factory=dict)
    config: list[BasicProviderConfig] = Field(default_factory=list)
@@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
        Fetch datasource providers for the given tenant.
        """

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
            if json_response.get("data"):
                for provider in json_response.get("data", []):
                    declaration = provider.get("declaration", {}) or {}
@@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
        Fetch datasource providers for the given tenant.
        """

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
            if json_response.get("data"):
                for provider in json_response.get("data", []):
                    declaration = provider.get("declaration", {}) or {}
@@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):

        tool_provider_id = DatasourceProviderID(provider_id)

-        def transformer(json_response: dict[str, Any]) -> dict:
+        def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
            data = json_response.get("data")
            if data:
                for datasource in data.get("declaration", {}).get("datasources", []):
@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.plugin.entities.endpoint import EndpointEntityWithInstance
 from core.plugin.impl.base import BasePluginClient
 from core.plugin.impl.exc import PluginDaemonInternalServerError
@@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError

 class PluginEndpointClient(BasePluginClient):
    def create_endpoint(
-        self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
+        self,
+        tenant_id: str,
+        user_id: str,
+        plugin_unique_identifier: str,
+        name: str,
+        settings: dict[str, Any],
    ) -> bool:
        """
        Create an endpoint for the given plugin.
@@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient):
            params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
        )

-    def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
+    def update_endpoint(
+        self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
+    ) -> bool:
        """
        Update the settings of the given endpoint.
        """
@@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
        provider: str,
        model_type: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
    ) -> AIModelEntity | None:
        """
        Get model schema
@@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
        return None

    def validate_provider_credentials(
-        self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
+        self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
    ) -> bool:
        """
        validate the credentials of the provider
@@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
        provider: str,
        model_type: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
    ) -> bool:
        """
        validate the credentials of the provider
@@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        prompt_messages: list[PromptMessage],
-        model_parameters: dict | None = None,
+        model_parameters: dict[str, Any] | None = None,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
@@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
        provider: str,
        model_type: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool] | None = None,
    ) -> int:
@@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        texts: list[str],
        input_type: str,
    ) -> EmbeddingResult:
@@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        documents: list[dict],
        input_type: str,
    ) -> EmbeddingResult:
@@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        texts: list[str],
    ) -> list[int]:
        """
@@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        query: str,
        docs: list[str],
        score_threshold: float | None = None,
@@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        query: MultimodalRerankInput,
        docs: list[MultimodalRerankInput],
        score_threshold: float | None = None,
@@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        content_text: str,
        voice: str,
    ) -> Generator[bytes, None, None]:
@@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        language: str | None = None,
    ):
        """
@@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        file: IO[bytes],
    ) -> str:
        """
@@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
        plugin_id: str,
        provider: str,
        model: str,
-        credentials: dict,
+        credentials: dict[str, Any],
        text: str,
    ) -> bool:
        """
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from typing import Any

 from requests import HTTPError

@@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient):
        original_plugin_unique_identifier: str,
        new_plugin_unique_identifier: str,
        source: PluginInstallationSource,
-        meta: dict,
+        meta: dict[str, Any],
    ) -> PluginInstallTaskStartResponse:
        """
        Upgrade a plugin.
@@ -96,11 +96,11 @@ class SimplePromptTransform(PromptTransform):
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
        query: str | None = None,
        context: str | None = None,
        histories: str | None = None,
-    ) -> tuple[str, dict]:
+    ) -> tuple[str, dict[str, Any]]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
@@ -187,7 +187,7 @@ class SimplePromptTransform(PromptTransform):
        self,
        app_mode: AppMode,
        pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
        query: str,
        context: str | None,
        files: Sequence["File"],
@@ -234,7 +234,7 @@ class SimplePromptTransform(PromptTransform):
        self,
        app_mode: AppMode,
        pre_prompt: str,
-        inputs: dict,
+        inputs: dict[str, Any],
        query: str,
        context: str | None,
        files: Sequence["File"],
@@ -856,7 +856,7 @@ class ProviderManager:
        secret_variables: list[str],
        cache_type: ProviderCredentialsCacheType,
        is_provider: bool = False,
-    ) -> dict:
+    ) -> dict[str, Any]:
        """Get and decrypt credentials with caching."""
        credentials_cache = ProviderCredentialsCache(
            tenant_id=tenant_id,
@@ -174,8 +174,8 @@ class RetrievalService:
        cls,
        dataset_id: str,
        query: str,
-        external_retrieval_model: dict | None = None,
-        metadata_filtering_conditions: dict | None = None,
+        external_retrieval_model: dict[str, Any] | None = None,
+        metadata_filtering_conditions: dict[str, Any] | None = None,
    ):
        stmt = select(Dataset).where(Dataset.id == dataset_id)
        dataset = db.session.scalar(stmt)
87
api/core/rag/datasource/vdb/vector_backend_registry.py
Normal file
@@ -0,0 +1,87 @@
"""Vector store backend discovery.
|
||||
|
||||
Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
|
||||
declares third-party dependencies and registers ``importlib`` entry points in group
|
||||
``dify.vector_backends`` (see each package's ``pyproject.toml``).
|
||||
|
||||
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
|
||||
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
|
||||
|
||||
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
|
||||
distribution; entry points take precedence when both exist.
|
||||
|
||||
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from importlib.metadata import entry_points
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
|
||||
|
||||
# module_path:class_name — optional fallback when no distribution registers the backend.
|
||||
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
|
||||
|
||||
|
||||
def clear_vector_factory_cache() -> None:
|
||||
"""Drop lazily loaded factories (for tests or plugin reload)."""
|
||||
_VECTOR_FACTORY_CACHE.clear()
|
||||
|
||||
|
||||
def _vector_backend_entry_points():
|
||||
return entry_points().select(group="dify.vector_backends")
|
||||
|
||||
|
||||
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
|
||||
for ep in _vector_backend_entry_points():
|
||||
if ep.name != vector_type:
|
||||
continue
|
||||
try:
|
||||
loaded = ep.load()
|
||||
except Exception:
|
||||
logger.exception("Failed to load vector backend entry point %s", ep.name)
|
||||
raise
|
||||
return loaded # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
|
||||
def _unsupported(vector_type: str) -> ValueError:
|
||||
installed = sorted(ep.name for ep in _vector_backend_entry_points())
|
||||
available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
|
||||
return ValueError(
|
||||
f"Vector store {vector_type!r} is not supported.{available_msg} "
|
||||
"Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
|
||||
"or register a dify.vector_backends entry point."
|
||||
)
|
||||
|
||||
|
||||
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
|
||||
if not target:
|
||||
raise _unsupported(vector_type)
|
||||
module_path, _, attr = target.partition(":")
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, attr) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
"""Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
|
||||
if vector_type in _VECTOR_FACTORY_CACHE:
|
||||
return _VECTOR_FACTORY_CACHE[vector_type]
|
||||
|
||||
plugin_cls = _load_plugin_factory(vector_type)
|
||||
if plugin_cls is not None:
|
||||
_VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
|
||||
return plugin_cls
|
||||
|
||||
cls = _load_builtin_factory(vector_type)
|
||||
_VECTOR_FACTORY_CACHE[vector_type] = cls
|
||||
return cls
|
||||
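A backend is resolved by its entry-point name. A hypothetical usage sketch; the `pyproject.toml` snippet in the comment is an assumed example based on the docstring's description of the package layout, not copied from a real package:

```python
# A backend package such as dify-vdb-weaviate would register itself in its
# pyproject.toml, e.g. (assumed module path, illustrative only):
#
#   [project.entry-points."dify.vector_backends"]
#   weaviate = "dify_vdb_weaviate.weaviate_vector:WeaviateVectorFactory"
#
# Resolution then goes through the registry:
from core.rag.datasource.vdb.vector_backend_registry import (
    clear_vector_factory_cache,
    get_vector_factory_class,
)

factory_cls = get_vector_factory_class("weaviate")  # raises ValueError if no backend is installed
factory = factory_cls()

clear_vector_factory_cache()  # e.g. in tests, after installing another backend
```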
@@ -9,6 +9,7 @@ from sqlalchemy import select

from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
@@ -85,137 +86,7 @@ class Vector:

    @staticmethod
    def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
        match vector_type:
            case VectorType.CHROMA:
                from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory

                return ChromaVectorFactory
            case VectorType.MILVUS:
                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

                return MilvusVectorFactory
            case VectorType.ALIBABACLOUD_MYSQL:
                from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
                    AlibabaCloudMySQLVectorFactory,
                )

                return AlibabaCloudMySQLVectorFactory
            case VectorType.MYSCALE:
                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

                return MyScaleVectorFactory
            case VectorType.PGVECTOR:
                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory

                return PGVectorFactory
            case VectorType.VASTBASE:
                from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory

                return VastbaseVectorFactory
            case VectorType.PGVECTO_RS:
                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory

                return PGVectoRSFactory
            case VectorType.QDRANT:
                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory

                return QdrantVectorFactory
            case VectorType.RELYT:
                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory

                return RelytVectorFactory
            case VectorType.ELASTICSEARCH:
                from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory

                return ElasticSearchVectorFactory
            case VectorType.ELASTICSEARCH_JA:
                from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
                    ElasticSearchJaVectorFactory,
                )

                return ElasticSearchJaVectorFactory
            case VectorType.TIDB_VECTOR:
                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

                return TiDBVectorFactory
            case VectorType.WEAVIATE:
                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory

                return WeaviateVectorFactory
            case VectorType.TENCENT:
                from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory

                return TencentVectorFactory
            case VectorType.ORACLE:
                from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory

                return OracleVectorFactory
            case VectorType.OPENSEARCH:
                from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory

                return OpenSearchVectorFactory
            case VectorType.ANALYTICDB:
                from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory

                return AnalyticdbVectorFactory
            case VectorType.COUCHBASE:
                from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory

                return CouchbaseVectorFactory
            case VectorType.BAIDU:
                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory

                return BaiduVectorFactory
            case VectorType.VIKINGDB:
                from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory

                return VikingDBVectorFactory
            case VectorType.UPSTASH:
                from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory

                return UpstashVectorFactory
            case VectorType.TIDB_ON_QDRANT:
                from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

                return TidbOnQdrantVectorFactory
            case VectorType.LINDORM:
                from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory

                return LindormVectorStoreFactory
            case VectorType.OCEANBASE | VectorType.SEEKDB:
                from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

                return OceanBaseVectorFactory
            case VectorType.OPENGAUSS:
                from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory

                return OpenGaussFactory
            case VectorType.TABLESTORE:
                from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory

                return TableStoreVectorFactory
            case VectorType.HUAWEI_CLOUD:
                from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory

                return HuaweiCloudVectorFactory
            case VectorType.MATRIXONE:
                from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory

                return MatrixoneVectorFactory
            case VectorType.CLICKZETTA:
                from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory

                return ClickzettaVectorFactory
            case VectorType.IRIS:
                from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory

                return IrisVectorFactory
            case VectorType.HOLOGRES:
                from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory

                return HologresVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
        return get_vector_factory_class(vector_type)

    def create(self, texts: list | None = None, **kwargs):
        if texts:
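The hunk above collapses the per-backend `match` table into a single delegation, so built-in and plugin backends resolve through the same cached registry. A hedged usage sketch (no-argument factory construction is an assumption, not shown in this diff):

```python
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType

factory_cls = Vector.get_vector_factory(VectorType.QDRANT)  # one path for built-ins and plugins
vector_factory = factory_cls()  # no-arg construction assumed; see AbstractVectorFactory for the real contract
```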
@@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).

:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.

The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""

import uuid
from unittest.mock import MagicMock

import pytest

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset

@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
    return doc


@pytest.fixture
def setup_mock_redis():
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock


class AbstractVectorTest:
    vector: BaseVector

    def __init__(self):
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())
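A sketch of how a workspace package test under `api/packages` might build on these helpers. Only `__init__` is visible in this hunk, so the subclass body and the final assertion are assumptions:

```python
from core.rag.datasource.vdb.vector_integration_test_support import (
    AbstractVectorTest,
    get_example_document,
)


class MyBackendVectorTest(AbstractVectorTest):  # hypothetical backend under test
    def __init__(self):
        super().__init__()
        # self.vector = MyBackendVector(collection_name=self.collection_name)


def test_example_document(setup_mock_redis):  # fixture resolved from api/packages/conftest.py
    t = MyBackendVectorTest()
    doc = get_example_document(t.example_doc_id)
    assert doc is not None
```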
@@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings):

        return embedding_results  # type: ignore

    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
        """Embed multimodal documents."""
        # use doc embedding cache or store if not exists
        file_id = multimodel_document["file_id"]
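This hunk is the first of many in this commit that widen bare `dict` annotations to `dict[str, Any]`. The gain is purely static: key types become checkable while values stay open. A small illustration (names are hypothetical, not from this commit):

```python
from typing import Any


def file_id_of(payload: dict[str, Any]) -> str:  # hypothetical helper
    # dict[str, Any] lets the checker verify that keys are strings, so key
    # lookups and key iteration are type-checked; the values stay unchecked.
    # A bare `dict` means dict[Any, Any] and also accepts non-str keys.
    return str(payload["file_id"])


file_id_of({"file_id": "abc123"})  # fine
file_id_of({1: "oops"})  # rejected by mypy/pyright, though it still runs
```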
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any


class Embeddings(ABC):
@@ -20,7 +21,7 @@ class Embeddings(ABC):
        raise NotImplementedError

    @abstractmethod
    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
        """Embed multimodal query."""
        raise NotImplementedError
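A minimal conforming implementation of the abstract method above, as a sketch. The import path and the dummy 768-dimension vector are assumptions, and the other abstract methods on `Embeddings` (elided in this hunk) would also need bodies before the class could be instantiated:

```python
from typing import Any

from core.rag.embedding.embedding_base import Embeddings  # assumed module path


class FakeEmbeddings(Embeddings):
    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
        # The cached implementation reads multimodel_document["file_id"], so a
        # payload is assumed to carry at least that key.
        _ = multimodel_document["file_id"]
        return [0.0] * 768  # fixed-size dummy vector
```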
@@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""

import csv
from typing import Any

import pandas as pd

@@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor):
        encoding: str | None = None,
        autodetect_encoding: bool = False,
        source_column: str | None = None,
        csv_args: dict | None = None,
        csv_args: dict[str, Any] | None = None,
    ):
        """Initialize with file path."""
        self._file_path = file_path
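A hedged usage sketch for the constructor above. The keyword names mirror the hunk; passing `file_path` positionally and calling `extract()` follow the usual `BaseExtractor` convention but are not shown in this diff:

```python
from core.rag.extractor.csv_extractor import CSVExtractor  # assumed module path

extractor = CSVExtractor(
    "/tmp/contacts.csv",
    encoding="utf-8",
    source_column="email",
    csv_args={"delimiter": ";", "quotechar": '"'},  # forwarded to the csv reader
)
documents = extractor.extract()  # assumed BaseExtractor entry point
```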
@@ -54,8 +54,8 @@ class BaseAPIClient:
        self,
        method: str,
        endpoint: str,
        query_params: dict | None = None,
        data: dict | None = None,
        query_params: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
        **kwargs,
    ) -> Response:
        stream = kwargs.pop("stream", False)
@@ -66,19 +66,25 @@ class BaseAPIClient:

        return self.session.request(method, url, params=query_params, json=data, **kwargs)

    def _get(self, endpoint: str, query_params: dict | None = None, **kwargs):
    def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
        return self._request("GET", endpoint, query_params=query_params, **kwargs)

    def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
    def _post(
        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
    ):
        return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs)

    def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
    def _put(
        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
    ):
        return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs)

    def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs):
    def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs):
        return self._request("DELETE", endpoint, query_params=query_params, **kwargs)

    def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs):
    def _patch(
        self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs
    ):
        return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs)
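The pattern above keeps each HTTP verb as a thin wrapper over one `_request()` that owns URL joining, serialization, and error handling. A self-contained sketch of the same shape (class name, endpoint, and session wiring are hypothetical):

```python
from typing import Any

import requests


class TinyClient:
    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def _request(
        self,
        method: str,
        endpoint: str,
        query_params: dict[str, Any] | None = None,
        data: dict[str, Any] | None = None,
    ):
        # One choke point for every verb: URL joining and JSON encoding live here.
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        return self.session.request(method, url, params=query_params, json=data)

    def _get(self, endpoint: str, query_params: dict[str, Any] | None = None):
        return self._request("GET", endpoint, query_params=query_params)


# client = TinyClient("https://app.watercrawl.dev")  # hypothetical base URL
# resp = client._get("api/v1/core/crawl-requests/", query_params={"page": 1})
```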
@@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
        finally:
            response.close()

    def process_response(self, response: Response) -> dict | bytes | list | None | Generator:
    def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator:
        if response.status_code == 401:
            raise WaterCrawlAuthenticationError(response)

@@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
        yield from generator

    def get_crawl_request_results(
        self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None
        self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None
    ):
        query_params = query_params or {}
        query_params.update({"page": page or 1, "page_size": page_size or 25})
@@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient):
            if event_data["type"] == "result":
                return event_data["data"]

    def download_result(self, result_object: dict):
    def download_result(self, result_object: dict[str, Any]):
        response = httpx.get(result_object["result"], timeout=None)
        try:
            response.raise_for_status()

@@ -120,7 +120,7 @@ class WaterCrawlProvider:
        }

    def _get_results(
        self, crawl_request_id: str, query_params: dict | None = None
        self, crawl_request_id: str, query_params: dict[str, Any] | None = None
    ) -> Generator[WatercrawlDocumentData, None, None]:
        page = 0
        page_size = 100
@@ -875,7 +875,11 @@ class DatasetRetrieval:
        return retrieval_resource_list

    def _on_retrieval_end(
        self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
        self,
        flask_app: Flask,
        documents: list[Document],
        message_id: str | None = None,
        timer: dict[str, Any] | None = None,
    ):
        """Handle retrieval end."""
        with flask_app.app_context():
@@ -980,7 +984,7 @@ class DatasetRetrieval:

        self._send_trace_task(message_id, documents, timer)

    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
        """Send trace task if trace manager is available."""
        trace_manager: TraceQueueManager | None = (
            self.application_generate_entity.trace_manager if self.application_generate_entity else None
@@ -1142,7 +1146,7 @@ class DatasetRetrieval:
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
        user_id: str,
        inputs: dict,
        inputs: dict[str, Any],
    ) -> list[DatasetRetrieverBaseTool] | None:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -1337,7 +1341,7 @@ class DatasetRetrieval:
        metadata_filtering_mode: str,
        metadata_model_config: ModelConfig,
        metadata_filtering_conditions: MetadataFilteringCondition | None,
        inputs: dict,
        inputs: dict[str, Any],
    ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
        document_query = select(DatasetDocument).where(
            DatasetDocument.dataset_id.in_(dataset_ids),
@@ -1417,7 +1421,7 @@ class DatasetRetrieval:
                metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition

    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
    def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
        if not inputs:
            return text
@@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
    return _case_routing


def __getattr__(name: str) -> dict:
def __getattr__(name: str) -> Any:
    """Lazy module-level access to routing tables."""
    if name == "CASE_ROUTING":
        return _get_case_routing()
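The signature above is PEP 562 module-level `__getattr__`: attribute access on the module is intercepted, so importing stays cheap and the routing table is built on first lookup. A self-contained sketch with hypothetical names:

```python
# lazy_tables.py (hypothetical module)
_cache: dict[str, dict] | None = None


def _build() -> dict[str, dict]:
    return {"routing": {"case": "route"}}  # stand-in for an expensive build


def __getattr__(name: str):
    # Called only when normal module attribute lookup fails (PEP 562).
    global _cache
    if name == "TABLES":
        if _cache is None:
            _cache = _build()
        return _cache
    raise AttributeError(name)


# elsewhere:
#   import lazy_tables
#   lazy_tables.TABLES  # table is built on this first access, then cached
```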
@@ -198,7 +198,7 @@ class Tool(ABC):
            message=ToolInvokeMessage.TextMessage(text=text),
        )

    def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
    def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
        """
        create a blob message

@@ -212,7 +212,7 @@ class Tool(ABC):
            meta=meta,
        )

    def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
    def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
        """
        create a json message
        """
@@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
        text: str

    class JsonMessage(BaseModel):
        json_object: dict | list
        json_object: dict[str, Any] | list[Any]
        suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")

    class BlobMessage(BaseModel):
@@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
    form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
    llm_description: str | None = None
    # MCP object and array type parameters use this field to store the schema
    input_schema: dict | None = None
    input_schema: dict[str, Any] | None = None

    @classmethod
    def get_simple_instance(
@@ -463,7 +463,7 @@ class ToolInvokeMeta(BaseModel):

    time_cost: float = Field(..., description="The time cost of the tool invoke")
    error: str | None = None
    tool_config: dict | None = None
    tool_config: dict[str, Any] | None = None

    @classmethod
    def empty(cls) -> ToolInvokeMeta:
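A quick illustration of the nested message model above. The field names come straight from the hunk; constructing `JsonMessage` directly is a sketch, since tools normally go through `create_json_message`:

```python
from typing import Any

from pydantic import BaseModel, Field


class JsonMessage(BaseModel):
    json_object: dict[str, Any] | list[Any]
    suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")


msg = JsonMessage(json_object={"rows": [1, 2, 3]}, suppress_output=True)
print(msg.model_dump())  # {'json_object': {'rows': [1, 2, 3]}, 'suppress_output': True}
```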
@@ -85,7 +85,8 @@ class ToolEngine:
        invocation_meta_dict: dict[str, ToolInvokeMeta] = {}

        def message_callback(
            invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
            invocation_meta_dict: dict[str, Any],
            messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
        ):
            for message in messages:
                if isinstance(message, ToolInvokeMeta):
@@ -200,7 +201,7 @@ class ToolEngine:
    @staticmethod
    def _invoke(
        tool: Tool,
        tool_parameters: dict,
        tool_parameters: dict[str, Any],
        user_id: str,
        conversation_id: str | None = None,
        app_id: str | None = None,
@@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
        invoke_from: InvokeFrom,
        hit_callback: DatasetIndexToolCallbackHandler,
        user_id: str,
        inputs: dict,
        inputs: dict[str, Any],
    ) -> list["DatasetRetrieverTool"]:
        """
        get dataset tool
@@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Any
from uuid import UUID

import numpy as np
@@ -50,7 +51,7 @@ def safe_json_value(v):
    return v


def safe_json_dict(d: dict):
def safe_json_dict(d: dict[str, Any]):
    if not isinstance(d, dict):
        raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
    return {k: safe_json_value(v) for k, v in d.items()}
@@ -196,11 +197,11 @@ class ToolFileMessageTransformer:

    @staticmethod
    def _with_tool_file_meta(
        meta: dict | None,
        meta: dict[str, Any] | None,
        *,
        tool_file_id: str | None = None,
        url: str | None = None,
    ) -> dict:
    ) -> dict[str, Any]:
        normalized_meta = meta.copy() if meta is not None else {}
        resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
        if resolved_tool_file_id and "tool_file_id" not in normalized_meta:
@@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):

class ApiBasedToolSchemaParser:
    @staticmethod
    def parse_openapi_to_tool_bundle(
        openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
        openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
    ) -> list[ApiToolBundle]:
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}
@@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser:
        return value

    @staticmethod
    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
    def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
        parameter = parameter or {}
        typ: str | None = None
        if parameter.get("format") == "binary":
@@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser:

    @staticmethod
    def parse_openapi_yaml_to_tool_bundle(
        yaml: str, extra_info: dict | None = None, warning: dict | None = None
        yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
    ) -> list[ApiToolBundle]:
        """
        parse openapi yaml to tool bundle
@@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser:
        warning = warning if warning is not None else {}
        extra_info = extra_info if extra_info is not None else {}

        openapi: dict = safe_load(yaml)
        openapi: dict[str, Any] = safe_load(yaml)
        if openapi is None:
            raise ToolApiSchemaError("Invalid openapi yaml.")
        return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)

    @staticmethod
    def parse_swagger_to_openapi(
        swagger: dict, extra_info: dict | None = None, warning: dict | None = None
        swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
    ) -> OpenAPISpecDict:
        warning = warning or {}
        """
@@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser:

    @staticmethod
    def parse_openai_plugin_json_to_tool_bundle(
        json: str, extra_info: dict | None = None, warning: dict | None = None
        json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
    ) -> list[ApiToolBundle]:
        """
        parse openapi plugin yaml to tool bundle
@@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser:

    @staticmethod
    def auto_parse_to_tool_bundle(
        content: str, extra_info: dict | None = None, warning: dict | None = None
        content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
    ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
        """
        auto parse to tool bundle
@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
            session.expunge(app)
            return app

    def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
    def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
        """
        transform the tool parameters

@@ -323,7 +323,7 @@ class WorkflowTool(Tool):

        return parameters_result, files

    def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
    def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
        """
        extract files from the result

@@ -355,7 +355,7 @@ class WorkflowTool(Tool):

        return result, files

    def _update_file_mapping(self, file_dict: dict):
    def _update_file_mapping(self, file_dict: dict[str, Any]):
        file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
        transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
        match transfer_method:
@@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol):
        original_document_id: str,
        chunks: Mapping[str, Any],
        batch: Any,
        summary_index_setting: dict | None = None,
        summary_index_setting: dict[str, Any] | None = None,
    ) -> IndexingResultDict: ...

    def get_preview_output(
        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
        self,
        chunks: Any,
        dataset_id: str,
        document_id: str,
        chunk_structure: str,
        summary_index_setting: dict[str, Any] | None,
    ) -> Preview: ...


class SummaryIndexServiceProtocol(Protocol):
    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
    ) -> None: ...
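These `Protocol` classes give structural typing: any object whose methods match the signatures satisfies the protocol, with no inheritance required. A self-contained sketch with hypothetical names:

```python
from typing import Any, Protocol


class SummaryService(Protocol):
    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
    ) -> None: ...


class NoopSummaryService:  # note: no inheritance from SummaryService
    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None
    ) -> None:
        return None


def run(service: SummaryService) -> None:
    service.generate_and_vectorize_summary("ds-1", "doc-1", is_preview=True)


run(NoopSummaryService())  # accepted structurally by the type checker
```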
@@ -1,4 +1,4 @@
from typing import Literal, Union
from typing import Any, Literal, Union

from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType
@@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData):
    mode: str = Field(default="visual", description="Schedule mode: visual or cron")
    frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
    cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
    visual_config: dict | None = Field(default=None, description="Visual configuration details")
    visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details")
    timezone: str = Field(default="UTC", description="Timezone for schedule execution")


@@ -75,7 +75,7 @@ class TriggerWebhookNode(Node[WebhookData]):
            outputs=outputs,
        )

    def generate_file_var(self, param_name: str, file: dict):
    def generate_file_var(self, param_name: str, file: dict[str, Any]):
        file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
        transfer_method_value = file.get("transfer_method")
        if transfer_method_value:
@@ -147,7 +147,7 @@ class TriggerWebhookNode(Node[WebhookData]):
                outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
                continue
            elif self.node_data.content_type == ContentType.BINARY:
                raw_data: dict = webhook_data.get("body", {}).get("raw", {})
                raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {})
                file_var = self.generate_file_var(param_name, raw_data)
                if file_var:
                    outputs[param_name] = file_var
@@ -10,6 +10,7 @@ import tempfile
from collections.abc import Generator
from io import BytesIO
from pathlib import Path
from typing import Any

import clickzetta
from pydantic import BaseModel, model_validator
@@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel):

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict):
    def validate_config(cls, values: dict[str, Any]):
        """Validate the configuration values.

        This method will first try to use CLICKZETTA_VOLUME_* environment variables,
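A compact illustration of the `model_validator(mode="before")` pattern above: the hook receives the raw input mapping before field parsing, so it can backfill values, for example from environment variables. The fields here are hypothetical:

```python
import os
from typing import Any

from pydantic import BaseModel, model_validator


class VolumeConfig(BaseModel):
    endpoint: str
    workspace: str = "default"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
        # Runs before field validation, so missing keys can be filled in here.
        values.setdefault("endpoint", os.environ.get("CLICKZETTA_VOLUME_ENDPOINT", ""))
        return values


cfg = VolumeConfig.model_validate({"workspace": "rag"})  # endpoint filled from env or ""
```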
@@ -65,7 +65,7 @@ class FileMetadata:
        return data

    @classmethod
    def from_dict(cls, data: dict) -> FileMetadata:
    def from_dict(cls, data: dict[str, Any]) -> FileMetadata:
        """Create instance from dictionary"""
        data = data.copy()
        data["created_at"] = datetime.fromisoformat(data["created_at"])
@@ -459,7 +459,7 @@ class FileLifecycleManager:
            newest_file=None,
        )

    def _create_version_backup(self, filename: str, metadata: dict):
    def _create_version_backup(self, filename: str, metadata: dict[str, Any]):
        """Create version backup"""
        try:
            # Read current file content
@@ -487,7 +487,7 @@ class FileLifecycleManager:
            logger.warning("Failed to load metadata: %s", e)
            return {}

    def _save_metadata(self, metadata_dict: dict):
    def _save_metadata(self, metadata_dict: dict[str, Any]):
        """Save metadata file"""
        try:
            metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
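A round-trip sketch of the `from_dict` convention above, where datetimes are serialized as ISO 8601 strings and revived with `datetime.fromisoformat()`. The dataclass and its `to_dict` are hypothetical; only the first lines of `from_dict` appear in the hunk:

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Any


@dataclass
class Meta:
    name: str
    created_at: datetime

    def to_dict(self) -> dict[str, Any]:
        return {"name": self.name, "created_at": self.created_at.isoformat()}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Meta":
        data = data.copy()  # avoid mutating the caller's dict
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        return cls(**data)


m = Meta("report.csv", datetime(2026, 1, 1, 12, 0))
assert Meta.from_dict(m.to_dict()) == m
```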
Some files were not shown because too many files have changed in this diff.