Merge commit '9c339239' into sandboxed-agent-rebase

Made-with: Cursor

# Conflicts:
#	api/README.md
#	api/controllers/console/app/workflow_draft_variable.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/plugin/backwards_invocation/model.py
#	api/core/prompt/advanced_prompt_transform.py
#	api/core/workflow/nodes/base/node.py
#	api/core/workflow/nodes/llm/llm_utils.py
#	api/core/workflow/nodes/llm/node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
#	api/core/workflow/runtime/graph_runtime_state.py
#	api/extensions/storage/base_storage.py
#	api/factories/variable_factory.py
#	api/pyproject.toml
#	api/services/variable_truncator.py
#	api/uv.lock
#	web/app/account/oauth/authorize/page.tsx
#	web/app/components/app/configuration/config-var/config-modal/field.tsx
#	web/app/components/base/alert.tsx
#	web/app/components/base/chat/chat/answer/human-input-content/executed-action.tsx
#	web/app/components/base/chat/chat/answer/more.tsx
#	web/app/components/base/chat/chat/answer/operation.tsx
#	web/app/components/base/chat/chat/answer/workflow-process.tsx
#	web/app/components/base/chat/chat/citation/index.tsx
#	web/app/components/base/chat/chat/citation/popup.tsx
#	web/app/components/base/chat/chat/citation/progress-tooltip.tsx
#	web/app/components/base/chat/chat/citation/tooltip.tsx
#	web/app/components/base/chat/chat/question.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/index.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx
#	web/app/components/base/markdown-blocks/form.tsx
#	web/app/components/base/prompt-editor/plugins/hitl-input-block/component-ui.tsx
#	web/app/components/base/tag-management/panel.tsx
#	web/app/components/base/tag-management/trigger.tsx
#	web/app/components/header/account-setting/index.tsx
#	web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
#	web/app/signin/utils/post-login-redirect.ts
#	web/eslint-suppressions.json
#	web/package.json
#	web/pnpm-lock.yaml
This commit is contained in:
Novice 2026-03-23 09:00:45 +08:00
commit cccff6768a
No known key found for this signature in database
GPG Key ID: A253106A7475AA3E
1009 changed files with 76072 additions and 18166 deletions

View File

@ -0,0 +1,168 @@
---
name: backend-code-review
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
---
# Backend Code Review
## When to use this skill
Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes:
- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).
Do NOT use this skill when:
- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
- The user is not asking for a review/analysis/improvement of backend code.
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).
## How to use this skill
Follow these steps when using this skill:
1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the users input. Keep the scope tight: review only what the user provided or explicitly referenced.
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback to perform the best-effort review.
3. Compose the final output strictly follow the **Required Output Format**.
Notes when using this skill:
- Always include actionable fixes or suggestions (including possible code snippets).
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.
## Checklist
- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review
## General Review Rules
### 1. Security Review
Check for:
- SQL injection vulnerabilities
- Server-Side Request Forgery (SSRF)
- Command injection
- Insecure deserialization
- Hardcoded secrets/credentials
- Improper authentication/authorization
- Insecure direct object references
### 2. Performance Review
Check for:
- N+1 queries
- Missing database indexes
- Memory leaks
- Blocking operations in async code
- Missing caching opportunities
### 3. Code Quality Review
Check for:
- Code forward compatibility
- Code duplication (DRY violations)
- Functions doing too much (SRP violations)
- Deep nesting / complex conditionals
- Magic numbers/strings
- Poor naming
- Missing error handling
- Incomplete type coverage
### 4. Testing Review
Check for:
- Missing test coverage for new code
- Tests that don't test behavior
- Flaky test patterns
- Missing edge cases
## Required Output Format
When this skill invoked, the response must exactly follow one of the two templates:
### Template A (any findings)
```markdown
# Code Review Summary
Found <X> critical issues need to be fixed:
## 🔴 Critical (Must Fix)
### 1. <brief description of the issue>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the issue>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each critical issue) ...
Found <Y> suggestions for improvement:
## 🟡 Suggestions (Should Consider)
### 1. <brief description of the suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the suggestion>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each suggestion) ...
Found <Z> optional nits:
## 🟢 Nits (Optional)
### 1. <brief description of the nit>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<explanation and references of the optional nit>
#### Suggested Fix
- <minor suggestions>
---
... (repeat for each nits) ...
## ✅ What's Good
- <Positive feedback on good patterns>
```
- If there are no critical issues or suggestions or option nits or good points, just omit that section.
- If the issue number is more than 10, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items.
- Don't compress the blank lines between sections; keep them as-is for readability.
- If there is any issue requires code changes, append a brief follow-up question to ask whether the user wants to apply the fix(es) after the structured output. For example: "Would you like me to use the Suggested fix(es) to address these issues?"
### Template B (no issues)
```markdown
## Code Review Summary
✅ No issues found.
```

View File

@ -0,0 +1,91 @@
# Rule Catalog — Architecture
## Scope
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.
## Rules
### Keep business logic out of controllers
- Category: maintainability
- Severity: critical
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
- Example:
- Bad:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = request.get_json() or {}
if payload.get("force") and current_user.role != "admin":
raise ValueError("only admin can force publish")
app = App.query.get(app_id)
app.status = "published"
db.session.commit()
return {"result": "ok"}
```
- Good:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = PublishRequest.model_validate(request.get_json() or {})
app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
return {"result": "ok"}
```
### Preserve layer dependency direction
- Category: best practices
- Severity: critical
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse.
- Example:
- Bad:
```python
# core/policy/publish_policy.py
from controllers.console.app import request_context
def can_publish() -> bool:
return request_context.current_user.is_admin
```
- Good:
```python
# core/policy/publish_policy.py
def can_publish(role: str) -> bool:
return role == "admin"
# service layer adapts web/user context to domain input
allowed = can_publish(role=current_user.role)
```
### Keep libs business-agnostic
- Category: maintainability
- Severity: critical
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
- Suggested fix:
- If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
- Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
- Example:
- Bad:
```python
# api/libs/conversation_filter.py
from services.conversation_service import ConversationService
def should_archive_conversation(conversation, tenant_id: str) -> bool:
# Domain policy and service dependency are leaking into libs.
service = ConversationService()
if service.has_paid_plan(tenant_id):
return conversation.idle_days > 90
return conversation.idle_days > 30
```
- Good:
```python
# api/libs/datetime_utils.py (business-agnostic helper)
def older_than_days(idle_days: int, threshold_days: int) -> bool:
return idle_days > threshold_days
# services/conversation_service.py (business logic stays in service/core)
from libs.datetime_utils import older_than_days
def should_archive_conversation(conversation, tenant_id: str) -> bool:
threshold_days = 90 if has_paid_plan(tenant_id) else 30
return older_than_days(conversation.idle_days, threshold_days)
```

View File

@ -0,0 +1,157 @@
# Rule Catalog — DB Schema Design
## Scope
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).
## Rules
### Do not query other tables inside `@property`
- Category: [maintainability, performance]
- Severity: critical
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
- Suggested fix:
- Keep model properties pure and local to already-loaded fields.
- Move cross-table data fetching to service/repository methods.
- For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
- Example:
- Bad:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def app_name(self) -> str:
with Session(db.engine, expire_on_commit=False) as session:
app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
return app.name
```
- Good:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def display_title(self) -> str:
return self.name or "Untitled"
# Service/repository layer performs explicit batch fetch for related App rows.
```
### Prefer including `tenant_id` in model definitions
- Category: maintainability
- Severity: suggestion
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
- Suggested fix:
- Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable.
- Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
- Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
- Example:
- Bad:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
### Detect and avoid duplicate/redundant indexes
- Category: performance
- Severity: suggestion
- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
- Suggested fix:
- Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
- Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
- Apply the same review standard in both model `__table_args__` and migration index DDL.
- Example:
- Bad:
```python
__table_args__ = (
sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
- Good:
```python
__table_args__ = (
# Keep the wider index unless profiling proves a dedicated short index is needed.
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
- Category: maintainability
- Severity: critical
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
- Suggested fix:
- Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
- Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
- Example:
- Bad:
```python
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
from models.types import AdjustedJSON
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
```
### Guard migration incompatibilities with dialect checks and shared types
- Category: maintainability
- Severity: critical
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
- Suggested fix:
- In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments.
- Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
- Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
- Example:
- Bad:
```python
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column(
"data_source_type",
sa.String(255),
server_default=sa.text("'database'::character varying"),
nullable=False,
)
)
```
- Good:
```python
def _is_pg(conn) -> bool:
return conn.dialect.name == "postgresql"
conn = op.get_bind()
default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
)
```

View File

@ -0,0 +1,61 @@
# Rule Catalog - Repositories Abstraction
## Scope
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).
## Rules
### Introduce repositories abstraction
- Category: maintainability
- Severity: suggestion
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
- Suggested fix:
- First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
- If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation).
- Example:
- Bad:
```python
# Existing repository is ignored and service uses ad-hoc table queries.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.session.execute(
select(App).where(App.id == app_id, App.tenant_id == tenant_id)
).scalar_one()
app.archived = True
self.session.commit()
```
- Good:
```python
# Case A: Existing repository must be reused for all table operations.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
app.archived = True
self.app_repo.save(app)
# If the query is missing, extend the existing abstraction.
active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
```
- Bad:
```python
# No repository exists, but large-domain query logic is scattered in service code.
class ConversationService:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
# many filters/joins/pagination variants duplicated across services
```
- Good:
```python
# Case B: Introduce repository for large/complex domains or storage variation.
class ConversationRepository(Protocol):
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...
class SqlAlchemyConversationRepository:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
class ConversationService:
def __init__(self, conversation_repo: ConversationRepository):
self.conversation_repo = conversation_repo
```

View File

@ -0,0 +1,139 @@
# Rule Catalog — SQLAlchemy Patterns
## Scope
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).
## Rules
### Use Session context manager with explicit transaction control behavior
- Category: best practices
- Severity: critical
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
- Suggested fix:
- Use **explicit `session.commit()`** after completing a related write unit.
- Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
- Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
- Example:
- Bad:
```python
# Missing commit: write may never be persisted.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Long transaction: external I/O inside a DB transaction.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
call_external_api()
```
- Good:
```python
# Option 1: explicit commit.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
session.commit()
# Option 2: scoped transaction with automatic commit/rollback.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Keep non-DB work outside transaction scope.
call_external_api()
```
### Enforce tenant_id scoping on shared-resource queries
- Category: security
- Severity: critical
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
- Example:
- Bad:
```python
stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
### Prefer SQLAlchemy expressions over raw SQL by default
- Category: maintainability
- Severity: suggestion
- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints.
- Example:
- Bad:
```python
row = session.execute(
text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
{"id": workflow_id, "tenant_id": tenant_id},
).first()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
row = session.execute(stmt).scalar_one_or_none()
```
### Protect write paths with concurrency safeguards
- Category: quality
- Severity: critical
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
- Suggested fix:
- **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or updated_at) guard in `WHERE` and treat `rowcount == 0` as a conflict.
- **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
- **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
- In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
- Example:
- Bad:
```python
# No tenant scope, no conflict detection, and no lock on a contested write path.
session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
session.commit() # silently overwrites concurrent updates
```
- Good:
```python
# 1) Optimistic lock (low contention, retry on conflict)
result = session.execute(
update(WorkflowRun)
.where(
WorkflowRun.id == run_id,
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.version == expected_version,
)
.values(status="cancelled", version=WorkflowRun.version + 1)
)
if result.rowcount == 0:
raise WorkflowStateConflictError("stale version, retry")
# 2) Redis distributed lock (cross-worker critical section)
lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
with redis_client.lock(lock_name, timeout=20):
session.execute(
update(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.values(status="cancelled")
)
session.commit()
# 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
run = session.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.with_for_update()
).scalar_one()
run.status = "cancelled"
session.commit()
```

View File

@ -0,0 +1 @@
../../.agents/skills/backend-code-review

View File

@ -0,0 +1,88 @@
name: Comment with Pyrefly Diff
on:
workflow_run:
workflows:
- Pyrefly Diff Check
types:
- completed
permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
issues: write
pull-requests: write
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: ${{ github.event.workflow_run.id }},
});
const match = artifacts.data.artifacts.find((artifact) =>
artifact.name === 'pyrefly_diff'
);
if (!match) {
throw new Error('pyrefly_diff artifact not found');
}
const download = await github.rest.actions.downloadArtifact({
owner: context.repo.owner,
repo: context.repo.repo,
artifact_id: match.id,
archive_format: 'zip',
});
fs.writeFileSync('pyrefly_diff.zip', Buffer.from(download.data));
- name: Unzip artifact
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
let prNumber = null;
try {
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
} catch (err) {
// Fallback to workflow_run payload if artifact is missing or incomplete.
const prs = context.payload.workflow_run.pull_requests || [];
if (prs.length > 0 && prs[0].number) {
prNumber = prs[0].number;
}
}
if (!prNumber) {
throw new Error('PR number not found in artifact or workflow_run payload');
}
const MAX_CHARS = 65000;
if (diff.length > MAX_CHARS) {
diff = diff.slice(0, MAX_CHARS);
diff = diff.slice(0, diff.lastIndexOf('\\n'));
diff += '\\n\\n... (truncated) ...';
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});

100
.github/workflows/pyrefly-diff.yml vendored Normal file
View File

@ -0,0 +1,100 @@
name: Pyrefly Diff Check
on:
pull_request:
paths:
- 'api/**/*.py'
permissions:
contents: read
jobs:
pyrefly-diff:
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
- name: Install dependencies
run: uv sync --project api --dev
- name: Prepare diagnostics extractor
run: |
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
- name: Run pyrefly on PR branch
run: |
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
- name: Compute diff
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@v4
with:
name: pyrefly_diff
path: |
pyrefly_diff.txt
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@v8
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs');
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
const prNumber = context.payload.pull_request.number;
const MAX_CHARS = 65000;
if (diff.length > MAX_CHARS) {
diff = diff.slice(0, MAX_CHARS);
diff = diff.slice(0, diff.lastIndexOf('\n'));
diff += '\n\n... (truncated) ...';
}
const body = diff.trim()
? [
'### Pyrefly Diff',
'<details>',
'<summary>base → PR</summary>',
'',
'```diff',
diff,
'```',
'</details>',
].join('\n')
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});

View File

@ -3,14 +3,22 @@ name: Web Tests
on:
workflow_call:
permissions:
contents: read
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Web Tests
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
defaults:
run:
shell: bash
@ -39,7 +47,58 @@ jobs:
run: pnpm install --frozen-lockfile
- name: Run tests
run: pnpm test:ci
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v6
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
include-hidden-files: true
retention-days: 1
merge-reports:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Download blob reports
uses: actions/download-artifact@v6
with:
path: web/.vitest-reports
pattern: blob-report-*
merge-multiple: true
- name: Merge reports
run: pnpm vitest --merge-reports --coverage --silent=passed-only
- name: Coverage Summary
if: always()

View File

@ -68,10 +68,9 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
@echo "📝 Running type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
test:
@ -132,7 +131,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make type-check - Run type checks (basedpyright, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@ -1,9 +1,5 @@
![cover-v5-optimized](./images/GitHub_README_if.png)
<p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
</p>
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·

View File

@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@ -50,14 +52,9 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@ -91,7 +88,6 @@ forbidden_modules =
core.logging
core.mcp
core.memory
core.model_manager
core.moderation
core.ops
core.plugin
@ -105,29 +101,16 @@ forbidden_modules =
core.variables
ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> models.model
core.workflow.nodes.datasource.datasource_node -> models.tools
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.entities -> configs
core.workflow.nodes.http_request.executor -> configs
core.workflow.nodes.http_request.node -> configs
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@ -140,36 +123,19 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.variables.segments
core.workflow.nodes.loop.entities -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
core.workflow.nodes.tool.tool_node -> models
core.workflow.nodes.agent.agent_node -> models.model
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
core.workflow.nodes.llm.node -> core.helper.code_executor
core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
@ -178,7 +144,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
@ -187,71 +152,13 @@ ignore_imports =
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
core.workflow.nodes.tool.tool_node -> core.tools.errors
core.workflow.conversation_variable_updater -> core.variables
core.workflow.graph_engine.entities.commands -> core.variables.variables
core.workflow.nodes.agent.agent_node -> core.variables.segments
core.workflow.nodes.answer.answer_node -> core.variables
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
core.workflow.nodes.list_operator.node -> core.variables
core.workflow.nodes.list_operator.node -> core.variables.segments
core.workflow.nodes.llm.node -> core.variables
core.workflow.nodes.loop.loop_node -> core.variables
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
core.workflow.nodes.tool.tool_node -> core.variables.segments
core.workflow.nodes.tool.tool_node -> core.variables.variables
core.workflow.nodes.trigger_webhook.node -> core.variables.types
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
core.workflow.nodes.variable_assigner.v1.node -> core.variables
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
core.workflow.runtime.read_only_wrappers -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables
core.workflow.runtime.variable_pool -> core.variables.consts
core.workflow.runtime.variable_pool -> core.variables.segments
core.workflow.runtime.variable_pool -> core.variables.variables
core.workflow.utils.condition.processor -> core.variables
core.workflow.utils.condition.processor -> core.variables.segments
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@ -259,7 +166,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services

View File

@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
1. Set up your application by visiting `http://localhost:3000`.
1. Optional: start the worker service (async tasks, runs from `api`).
1. Start the worker service (async and scheduler tasks, runs from `api`).
```bash
./dev/start-worker

File diff suppressed because one or more lines are too long

View File

@ -765,7 +765,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -15,11 +15,11 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.file import helpers as file_helpers
from core.workflow.variables.segment_group import SegmentGroup
from core.workflow.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
@ -133,11 +133,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
full_content=fields.Raw(attribute=_serialize_full_content),
)
_WORKFLOW_DRAFT_VARIABLE_FIELDS = {
**_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
"value": fields.Raw(attribute=_serialize_var_value),
"full_content": fields.Raw(attribute=_serialize_full_content),
}
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,

View File

@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.variables.types import SegmentType
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type

View File

@ -44,6 +44,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.app_fields import (
app_detail_fields_with_site,
deleted_tool_fields,
@ -225,7 +226,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -23,6 +23,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
@ -100,6 +101,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -36,9 +36,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required(view: Callable[P, R]):
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
@ -214,9 +214,9 @@ def cloud_utm_record(view: Callable[P, R]):
return decorated
def setup_required(view: Callable[P, R]):
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"

View File

@ -137,7 +137,7 @@ class FilePreviewApi(Resource):
if args.as_attachment:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["Content-Type"] = "application/octet-stream"
enforce_download_for_html(
response,

View File

@ -64,6 +64,10 @@ class ToolFileApi(Resource):
if not stream or not tool_file:
raise NotFound("file is not found")
except NotFound:
raise
except Exception:
raise UnsupportedFileTypeError()

View File

@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from core.workflow.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@ -31,6 +31,7 @@ from core.model_runtime.errors.invoke import InvokeError
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
@ -280,7 +281,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -3,7 +3,8 @@ from typing import Any
from flask import request
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
@ -17,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from libs import helper
from libs.login import current_user
from models import Account
from models.dataset import Pipeline
from models.dataset import Dataset, Pipeline
from models.engine import db
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
from services.file_service import FileService
@ -65,6 +66,12 @@ class DatasourcePluginsApi(DatasetApiResource):
)
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
# Get query parameter to determine published or draft
is_published: bool = request.args.get("is_published", default=True, type=bool)
@ -104,6 +111,12 @@ class DatasourceNodeRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
@ -161,6 +174,12 @@ class PipelineRunApi(DatasetApiResource):
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
if not isinstance(current_user, Account):

View File

@ -24,6 +24,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@ -121,6 +122,6 @@ class WorkflowTaskStopApi(WebApiResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []

View File

@ -1,7 +1,8 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[

View File

@ -2,12 +2,12 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from core.workflow.file import FileUploadConfig
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
from models.model import AppMode
@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
class RagPipelineVariableEntity(WorkflowVariableEntity):
"""
Rag Pipeline Variable Entity.
"""
@ -314,7 +260,7 @@ class AppConfig(BaseModel):
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
variables: list[WorkflowVariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@ -1,6 +1,7 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity
from core.workflow.variables.input_entities import VariableEntity
from models.workflow import Workflow

View File

@ -26,7 +26,6 @@ from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.sandbox import Sandbox
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@ -35,6 +34,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from extensions.ext_redis import redis_client

View File

@ -850,16 +850,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
return
yield # Make this a generator
yield from ()
def _handle_annotation_reply_event(
self, event: QueueAnnotationReplyEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
return
yield # Make this a generator
yield from ()
def _handle_message_replace_event(
self, event: QueueMessageReplaceEvent, **kwargs

View File

@ -175,7 +175,7 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")

View File

@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.workflow.variables.input_entities import VariableEntityType
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
from core.workflow.variables.input_entities import VariableEntity
class BaseAppGenerator:

View File

@ -122,7 +122,7 @@ class AppQueueManager(ABC):
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:

View File

@ -49,7 +49,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
@ -62,6 +61,7 @@ from core.workflow.enums import (
from core.workflow.file import FILE_MODEL_IDENTITY, File
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db

View File

@ -11,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
@ -21,6 +20,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Document, Pipeline

View File

@ -1,12 +1,12 @@
import logging
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.variables import VariableBase
logger = logging.getLogger(__name__)

View File

@ -0,0 +1,5 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

View File

@ -0,0 +1,110 @@
from __future__ import annotations
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
class DifyCredentialsProvider:
tenant_id: str
provider_manager: ProviderManager
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
self.tenant_id = tenant_id
self.provider_manager = provider_manager or ProviderManager()
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
raise ValueError(f"Provider {provider_name} does not exist.")
provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
provider_model.raise_for_status()
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
return credentials
class DifyModelFactory:
tenant_id: str
model_manager: ModelManager
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
self.tenant_id = tenant_id
self.model_manager = model_manager or ModelManager()
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
return self.model_manager.get_model_instance(
tenant_id=self.tenant_id,
provider=provider_name,
model_type=ModelType.LLM,
model=model_name,
)
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
return (
DifyCredentialsProvider(tenant_id=tenant_id),
DifyModelFactory(tenant_id=tenant_id),
)
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=completion_params,
stop=stop,
)

93
api/core/app/llm/quota.py Normal file
View File

@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -168,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@ -181,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@ -294,7 +294,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
self._task_state.llm_result.message.get_text_content()
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@ -408,7 +408,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
if llm_result.message.content
else ""
)

View File

@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None

View File

@ -1,37 +1,97 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from core.workflow.nodes.code.entities import CodeLanguage
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
*,
language: CodeLanguage,
code: str,
inputs: Mapping[str, Any],
) -> Mapping[str, Any]:
return CodeExecutor.execute_workflow_code_template(
language=language,
code=code,
inputs=inputs,
)
def is_execution_error(self, error: Exception) -> bool:
return isinstance(error, CodeExecutionError)
@final
class DifyNodeFactory(NodeFactory):
"""
@ -45,23 +105,11 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
template_transform_max_output_length: int | None = None,
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
self._code_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
@ -71,21 +119,27 @@ class DifyNodeFactory(NodeFactory):
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = (
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
)
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = (
document_extractor_unstructured_api_config
or UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
)
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
)
self._http_request_config = build_http_request_config(
max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE,
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@ -126,7 +180,6 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
@ -146,11 +199,45 @@ class DifyNodeFactory(NodeFactory):
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.KNOWLEDGE_INDEX:
return KnowledgeIndexNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
index_processor=IndexProcessor(),
summary_index_service=SummaryIndex(),
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
@ -169,9 +256,95 @@ class DifyNodeFactory(NodeFactory):
unstructured_api_config=self._document_extractor_unstructured_api_config,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)

View File

@ -1,16 +1,39 @@
import logging
from collections.abc import Generator
from threading import Lock
from typing import Any, cast
from sqlalchemy import select
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.file import File
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -103,3 +126,238 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
datasource_runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
return datasource_runtime.get_icon_url(tenant_id)
@classmethod
def stream_online_results(
cls,
*,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[DatasourceMessage, None, Any]:
"""
Pull-based streaming of domain messages from datasource plugins.
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
"""
ds_type = DatasourceProviderType.value_of(datasource_type)
runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=ds_type,
)
dsp_service = DatasourceProviderService()
credentials = dsp_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
if credentials:
doc_runtime.runtime.credentials = credentials
if datasource_param is None:
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
user_id=user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_param.workspace_id,
page_id=datasource_param.page_id,
type=datasource_param.type,
),
provider_type=ds_type,
)
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
if credentials:
drive_runtime.runtime.credentials = credentials
if online_drive_request is None:
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
inner_gen = drive_runtime.online_drive_download_file(
user_id=user_id,
request=OnlineDriveDownloadFileRequest(
id=online_drive_request.id,
bucket=online_drive_request.bucket,
),
provider_type=ds_type,
)
else:
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
# Bridge through to caller while preserving generator return contract
yield from inner_gen
# No structured final data here; node/adapter will assemble outputs
return {}
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
ds_type = DatasourceProviderType.value_of(datasource_type)
messages = cls.stream_online_results(
user_id=user_id,
datasource_name=datasource_name,
datasource_type=datasource_type,
provider_id=provider_id,
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
datasource_param=datasource_param,
online_drive_request=online_drive_request,
)
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
)
variables: dict[str, Any] = {}
file_out: File | None = None
for message in transformed:
mtype = message.type
if mtype in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
wanted_ds_type = ds_type in {
DatasourceProviderType.ONLINE_DRIVE,
DatasourceProviderType.ONLINE_DOCUMENT,
}
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
url = message.message.text
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with session_factory.create_session() as session:
stmt = select(ToolFile).where(
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
)
datasource_file = session.scalar(stmt)
if not datasource_file:
raise ValueError(
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
)
mime_type = datasource_file.mimetype
if datasource_file is not None:
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(mime_type),
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
elif mtype == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
)
elif mtype == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
name = message.message.variable_name
value = message.message.variable_value
if message.message.stream:
assert isinstance(value, str), "stream variable_value must be str"
variables[name] = variables.get(name, "") + value
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
else:
variables[name] = value
elif mtype == DatasourceMessage.MessageType.FILE:
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
f = message.meta.get("file")
if isinstance(f, File):
file_out = f
else:
pass
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
variable_pool.add([node_id, "file"], file_out)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={**variables},
)
)
else:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_out,
"datasource_type": ds_type,
},
)
)
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
)
return file_info

View File

@ -379,4 +379,11 @@ class OnlineDriveDownloadFileRequest(BaseModel):
"""
id: str = Field(..., description="The id of the file")
bucket: str | None = Field(None, description="The name of the bucket")
bucket: str = Field("", description="The name of the bucket")
@field_validator("bucket", mode="before")
@classmethod
def _coerce_bucket(cls, v) -> str:
if v is None:
return ""
return str(v)

View File

@ -1,6 +1,5 @@
import logging
from collections.abc import Mapping
from enum import StrEnum
from threading import Lock
from typing import Any
@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
from core.helper.http_client_pooling import get_pooled_http_client
from core.workflow.nodes.code.entities import CodeLanguage
logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
@ -40,12 +40,6 @@ class CodeExecutionResponse(BaseModel):
data: Data
class CodeLanguage(StrEnum):
PYTHON3 = "python3"
JINJA2 = "jinja2"
JAVASCRIPT = "javascript"
def _build_code_executor_client() -> httpx.Client:
return httpx.Client(
verify=CODE_EXECUTION_SSL_VERIFY,

View File

@ -5,7 +5,7 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
from core.variables.utils import dumps_with_segments
from core.workflow.variables.utils import dumps_with_segments
class TemplateTransformer(ABC):

View File

@ -4,10 +4,10 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

View File

@ -1,5 +1,5 @@
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
@ -35,9 +35,12 @@ class ModelInstance:
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle
self.model = model
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
self.model_type_instance = self.provider_model_bundle.model_type_instance
self.load_balancing_manager = self._get_load_balancing_manager(
configuration=provider_model_bundle.configuration,
@ -163,7 +166,7 @@ class ModelInstance:
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
@ -191,7 +194,7 @@ class ModelInstance:
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
@ -215,7 +218,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
@ -243,7 +246,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
@ -264,7 +267,7 @@ class ModelInstance:
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
@ -294,7 +297,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
@ -328,7 +331,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
@ -352,7 +355,7 @@ class ModelInstance:
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
@ -373,7 +376,7 @@ class ModelInstance:
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
@ -396,7 +399,7 @@ class ModelInstance:
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
@ -469,7 +472,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
model=self.model_name, credentials=self.credentials, language=language
)

View File

@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages."""
def get_history_prompt_messages(
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

View File

@ -83,19 +83,21 @@ def _merge_tool_call_delta(
tool_call.function.arguments += delta.function.arguments
def _build_llm_result_from_first_chunk(
def _build_llm_result_from_chunks(
model: str,
prompt_messages: Sequence[PromptMessage],
chunks: Iterator[LLMResultChunk],
) -> LLMResult:
"""
Build a single `LLMResult` from the first returned chunk.
Build a single `LLMResult` by accumulating all returned chunks.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
Some models only support streaming output (e.g. Qwen3 open-source edition)
and the plugin side may still implement the response via a chunked stream,
so all chunks must be consumed and concatenated into a single ``LLMResult``.
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
The ``usage`` is taken from the last chunk that carries it, which is the
typical convention for streaming responses (the final chunk contains the
aggregated token counts).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@ -104,24 +106,27 @@ def _build_llm_result_from_first_chunk(
tools_calls: list[AssistantPromptMessage.ToolCall] = []
try:
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
for chunk in chunks:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if chunk.delta.message.tool_calls:
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
if chunk.delta.usage:
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception:
logger.exception("Error while consuming non-stream plugin chunk iterator.")
raise
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
# Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections).
close = getattr(chunks, "close", None)
if callable(close):
close()
return LLMResult(
model=model,
@ -174,7 +179,7 @@ def _normalize_non_stream_plugin_result(
) -> LLMResult:
if isinstance(result, LLMResult):
return result
return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
def _increase_tool_call(

View File

@ -14,6 +14,7 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
@ -99,6 +100,16 @@ class AliyunDataTrace(BaseTraceInstance):
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
raise ValueError(f"Aliyun get project url failed: {str(e)}")
def _extract_app_id(self, trace_info: BaseTraceInfo) -> str:
"""Extract app_id from trace_info, trying metadata first then message_data."""
app_id = trace_info.metadata.get("app_id")
if app_id:
return str(app_id)
message_data = getattr(trace_info, "message_data", None)
if message_data is not None:
return str(getattr(message_data, "app_id", ""))
return ""
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
@ -143,13 +154,16 @@ class AliyunDataTrace(BaseTraceInstance):
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
DIFY_APP_ID: self._extract_app_id(trace_info),
},
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
@ -441,6 +455,8 @@ class AliyunDataTrace(BaseTraceInstance):
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
app_id = self._extract_app_id(trace_info)
if message_span_id:
message_span = SpanData(
trace_id=trace_metadata.trace_id,
@ -449,13 +465,16 @@ class AliyunDataTrace(BaseTraceInstance):
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
DIFY_APP_ID: app_id,
},
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
@ -469,13 +488,16 @@ class AliyunDataTrace(BaseTraceInstance):
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
attributes={
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
**({DIFY_APP_ID: app_id} if message_span_id is None else {}),
},
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,

View File

@ -3,6 +3,9 @@ from typing import Final
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
# Dify-specific attributes
DIFY_APP_ID: Final[str] = "dify.app_id"
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"

View File

@ -155,6 +155,26 @@ def wrap_span_metadata(metadata, **kwargs):
return metadata
# Mapping from NodeType string values to OpenInference span kinds.
# NodeType values not listed here default to CHAIN.
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
"llm": OpenInferenceSpanKindValues.LLM,
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
"tool": OpenInferenceSpanKindValues.TOOL,
"agent": OpenInferenceSpanKindValues.AGENT,
}
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
"""Return the OpenInference span kind for a given workflow node type.
Covers every ``NodeType`` enum value. Nodes that do not have a
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
"""
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@ -289,9 +309,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
)
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
span_kind = _get_node_span_kind(node_execution.node_type)
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
@ -306,12 +325,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(

View File

@ -14,10 +14,9 @@ class BaseTraceInstance(ABC):
Base trace instance for ops trace services
"""
@abstractmethod
def __init__(self, trace_config: BaseTracingConfig):
"""
Abstract initializer for the trace instance.
Initializer for the trace instance.
Distribute trace tasks by matching entities
"""
self.trace_config = trace_config

View File

@ -129,11 +129,11 @@ class LangfuseSpan(BaseModel):
default=None,
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
)
start_time: datetime | str | None = Field(
start_time: datetime | None = Field(
default_factory=datetime.now,
description="The time at which the span started, defaults to the current time.",
)
end_time: datetime | str | None = Field(
end_time: datetime | None = Field(
default=None,
description="The time at which the span ended. Automatically set by span.end().",
)
@ -146,7 +146,7 @@ class LangfuseSpan(BaseModel):
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
"via the API.",
)
level: str | None = Field(
level: LevelEnum | None = Field(
default=None,
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
"traces with elevated error levels and for highlighting in the UI.",
@ -222,16 +222,16 @@ class LangfuseGeneration(BaseModel):
default=None,
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
)
start_time: datetime | str | None = Field(
start_time: datetime | None = Field(
default_factory=datetime.now,
description="The time at which the generation started, defaults to the current time.",
)
completion_start_time: datetime | str | None = Field(
completion_start_time: datetime | None = Field(
default=None,
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
"down into time until completion started and completion duration.",
)
end_time: datetime | str | None = Field(
end_time: datetime | None = Field(
default=None,
description="The time at which the generation ended. Automatically set by generation.end().",
)

View File

@ -41,8 +41,8 @@ logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
def __getitem__(self, key: str) -> dict[str, Any]:
match key:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@ -149,7 +149,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
raise KeyError(f"Unsupported tracing provider: {key}")
provider_config_map = OpsTraceProviderConfigMap()

View File

@ -1,6 +1,6 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Union
from typing import Any, Union
from urllib.parse import urlparse
from sqlalchemy import select
@ -9,7 +9,7 @@ from models.engine import db
from models.model import Message
def filter_none_values(data: dict):
def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
new_data = {}
for key, value in data.items():
if value is None:

View File

@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@ -120,26 +118,37 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
user=user_id,
)
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
if isinstance(response, Generator):
def handle_non_streaming(
response: LLMResultWithStructuredOutput,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
yield LLMResultChunkWithStructuredOutput(
model=response.model,
prompt_messages=[],
system_fingerprint=response.system_fingerprint,
structured_output=response.structured_output,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle_non_streaming(response)
return handle()
else:
if response.usage:
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
yield LLMResultChunkWithStructuredOutput(
model=response.model,
prompt_messages=[],
system_fingerprint=response.system_fingerprint,
structured_output=response.structured_output,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
return handle_non_streaming(response)
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):

View File

@ -4,6 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
prompt_messages = []
@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@ -111,6 +116,7 @@ class AdvancedPromptTransform(PromptTransform):
parser=parser,
prompt_inputs=prompt_inputs,
model_config=model_config,
model_instance=model_instance,
)
if query:
@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform):
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
prompt_messages = self._append_chat_histories(
memory,
memory_config,
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
if files and query is not None:
for file in files:
prompt_message_contents.append(
@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: MemoryConfig.RolePrefix,
parser: PromptTemplateParser,
prompt_inputs: Mapping[str, str],
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#histories#" in parser.variable_keys:
@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform):
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
rest_tokens = self._calculate_rest_token(
[tmp_human_message],
model_config=model_config,
model_instance=model_instance,
)
histories = self._get_history_messages_from_memory(
memory=memory,

View File

@ -41,13 +41,15 @@ class AgentHistoryPromptTransform(PromptTransform):
if not self.memory:
return prompt_messages
max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
self.model_config.model,
self.model_config.credentials,
self.history_messages,
)
if curr_message_tokens <= max_token_limit:
return self.history_messages
@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform):
# a message is start with UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
self.model_config.model,
self.model_config.credentials,
prompt_messages,
)
# if current message token is overflow, drop all the prompts in current message and break
if curr_message_tokens > max_token_limit:

View File

@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _resolve_model_runtime(
self,
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> tuple[ModelInstance, AIModelEntity]:
if model_instance is None:
if model_config is None:
raise ValueError("Either model_config or model_instance must be provided.")
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_instance.credentials = model_config.credentials
model_instance.parameters = model_config.parameters
model_instance.stop = model_config.stop
model_schema = model_instance.model_type_instance.get_model_schema(
model=model_instance.model_name,
credentials=model_instance.credentials,
)
if model_schema is None:
if model_config is None:
raise ValueError("Model schema not found for the provided model instance.")
model_schema = model_config.model_schema
return model_instance, model_schema
def _append_chat_histories(
self,
memory: BaseMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
rest_tokens = self._calculate_rest_token(
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
prompt_messages.extend(histories)
return prompt_messages
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
self,
prompt_messages: list[PromptMessage],
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> int:
model_instance, model_schema = self._resolve_model_runtime(
model_config=model_config,
model_instance=model_instance,
)
model_parameters = model_instance.parameters
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
model_parameters.get(parameter_rule.name)
or model_parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens

View File

@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
if memory:
tmp_human_message = UserPromptMessage(content=prompt)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
histories = self._get_history_messages_from_memory(
memory=memory,
memory_config=MemoryConfig(

View File

@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
metrics=self.config.metrics,
include_values=True,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
vector=None,
content=None,
top_k=1,
filter=f"ref_doc_id='{id}'",
)
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data=None,
collection_data_filter=f"ref_doc_id IN {ids_str}",
)
self._client.delete_collection_data(request)
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
collection_data=None, # ty: ignore [invalid-argument-type]
collection_data=None,
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
)
self._client.delete_collection_data(request)
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=query_vector,
content=None, # ty: ignore [invalid-argument-type]
content=None,
top_k=kwargs.get("top_k", 4),
filter=where_clause,
)
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
collection=self._collection_name,
include_values=kwargs.pop("include_values", True),
metrics=self.config.metrics,
vector=None, # ty: ignore [invalid-argument-type]
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
filter=where_clause,

View File

@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
)

View File

@ -19,7 +19,7 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
raise NotImplementedError
@abstractmethod
@ -27,14 +27,14 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str):
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
@ -46,7 +46,7 @@ class BaseVector(ABC):
raise NotImplementedError
@abstractmethod
def delete(self):
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:

View File

@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
)
.first()
)
@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
self._model_instance.model_name, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
)
.first()
)
@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
self._model_instance.model_name, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings):
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings):
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)

View File

@ -0,0 +1,252 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory
from .processor.paragraph_index_processor import ParagraphIndexProcessor
logger = logging.getLogger(__name__)
class IndexProcessor:
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview = index_processor.format_preview(chunks)
data = Preview(
chunk_structure=preview["chunk_structure"],
total_segments=preview["total_segments"],
preview=[],
parent_mode=None,
qa_preview=[],
)
if "parent_mode" in preview:
data.parent_mode = preview["parent_mode"]
for item in preview["preview"]:
if "content" in item and "child_chunks" in item:
data.preview.append(
PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
)
elif "question" in item and "answer" in item:
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
elif "content" in item:
data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
return data
def index_and_clean(
self,
dataset_id: str,
document_id: str,
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
summary_index_setting: dict | None = None,
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:
raise KnowledgeIndexNodeError(f"Document {document_id} not found.")
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
dataset_name_value = dataset.name
document_name_value = document.name
created_at_value = document.created_at
if summary_index_setting is None:
summary_index_setting = dataset.summary_index_setting
index_node_ids = []
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
if original_document_id:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
indexing_start_at = time.perf_counter()
# delete from vector index
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
with session_factory.create_session() as session, session.begin():
if index_node_ids:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
session.execute(segment_delete_stmt)
index_processor.index(dataset, document, chunks)
indexing_end_at = time.perf_counter()
with session_factory.create_session() as session, session.begin():
document.indexing_latency = indexing_end_at - indexing_start_at
document.indexing_status = "completed"
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
session.query(func.sum(DocumentSegment.word_count))
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
)
.scalar()
) or 0
# Update need_summary based on dataset's summary_index_setting
if summary_index_setting and summary_index_setting.get("enable") is True:
document.need_summary = True
else:
document.need_summary = False
session.add(document)
# update document segment status
session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
).update(
{
DocumentSegment.status: "completed",
DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}
)
return {
"dataset_id": dataset_id,
"dataset_name": dataset_name_value,
"batch": batch,
"document_id": document_id,
"document_name": document_name_value,
"created_at": created_at_value.timestamp(),
"display_status": "completed",
}
def get_preview_output(
self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
) -> Preview:
doc_language = None
with session_factory.create_session() as session:
if document_id:
document = session.query(Document).filter_by(id=document_id).first()
else:
document = None
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")
if summary_index_setting is None:
summary_index_setting = dataset.summary_index_setting
if document:
doc_language = document.doc_language
indexing_technique = dataset.indexing_technique
tenant_id = dataset.tenant_id
preview_output = self.format_preview(chunk_structure, chunks)
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
if preview_output.preview is not None:
chunk_count = len(preview_output.preview)
logger.info(
"Generating summaries for %s chunks in preview mode (dataset: %s)",
chunk_count,
dataset_id,
)
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: PreviewItem) -> None:
"""Generate summary for a single chunk."""
if flask_app:
with flask_app.app_context():
if preview_item.content is not None:
# Set Flask application context in worker thread
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview_item.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item.summary = summary
else:
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview_item.content if preview_item.content is not None else "",
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output.preview))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output.preview))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item) for preview_item in preview_output.preview
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode, if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise KnowledgeIndexNodeError(error_summary)
completed_count = sum(1 for item in preview_output.preview if item.summary is not None)
logger.info(
"Completed summary generation for preview chunks: %s/%s succeeded",
completed_count,
len(preview_output.preview),
)
return preview_output

View File

@ -75,15 +75,15 @@ class BaseIndexProcessor(ABC):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
) -> None:
raise NotImplementedError
@abstractmethod
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
raise NotImplementedError
@abstractmethod
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
raise NotImplementedError
@abstractmethod

View File

@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@ -115,7 +115,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
) -> None:
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
@ -130,7 +130,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
@ -196,7 +196,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
documents: list[Any] = []
all_multimodal_documents: list[Any] = []
if isinstance(chunks, list):
@ -469,12 +469,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if not isinstance(result, LLMResult):
raise ValueError("Expected LLMResult when stream=False")
summary_content = getattr(result.message, "content", "")
summary_content = result.message.get_text_content()
usage = result.usage
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@ -126,7 +126,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
) -> None:
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
@ -272,7 +272,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_nodes.append(child_document)
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:

View File

@ -139,14 +139,14 @@ class QAIndexProcessor(BaseIndexProcessor):
multimodal_documents: list[AttachmentDocument] | None = None,
with_keywords: bool = True,
**kwargs,
):
) -> None:
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if multimodal_documents and dataset.is_multimodal:
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
@ -206,7 +206,7 @@ class QAIndexProcessor(BaseIndexProcessor):
docs.append(doc)
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:

View File

@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model=self.rerank_model_instance.model_name,
model_type=ModelType.RERANK,
)
if not is_support_vision:

View File

@ -248,19 +248,22 @@ class DatasetRetrieval:
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
ext_meta = item.metadata or {}
title = ext_meta.get("title") or ""
doc_id = ext_meta.get("document_id") or title
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id"),
document_name=item.metadata.get("title"),
dataset_id=ext_meta.get("dataset_id") or "",
dataset_name=ext_meta.get("dataset_name") or "",
document_id=str(doc_id),
document_name=ext_meta.get("title") or "",
data_source_type="external",
retriever_from="workflow",
score=item.metadata.get("score"),
doc_metadata=item.metadata,
score=float(ext_meta.get("score") or 0.0),
doc_metadata=ext_meta,
),
title=item.metadata.get("title"),
title=title,
content=item.page_content,
)
retrieval_resource_list.append(source)

View File

@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

View File

@ -0,0 +1,86 @@
import concurrent.futures
import logging
from core.db.session_factory import session_factory
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
class SummaryIndex:
def generate_and_vectorize_summary(
self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
) -> None:
if is_preview:
with session_factory.create_session() as session:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset or dataset.indexing_technique != "high_quality":
return
if summary_index_setting is None:
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
if not document_id:
return
document = session.query(Document).filter_by(id=document_id).first()
# Skip qa_model documents
if document is None or document.doc_form == "qa_model":
return
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset_id,
document_id=document_id,
status="completed",
enabled=True,
)
segments = query.all()
segment_ids = [segment.id for segment in segments]
if not segment_ids:
return
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.status == "completed",
)
.all()
)
completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
# Preview mode should process segments that are MISSING completed summaries
pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
# If all segments already have completed summaries, nothing to do in preview mode
if not pending_segment_ids:
return
max_workers = min(10, len(pending_segment_ids))
def process_segment(segment_id: str) -> None:
"""Process a single segment in a thread with a fresh DB session."""
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
if segment is None:
return
try:
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
except Exception:
logger.exception(
"Failed to generate summary for segment %s",
segment_id,
)
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_segment, segment_id) for segment_id in pending_segment_ids]
concurrent.futures.wait(futures)
else:
generate_summary_index_task.delay(dataset_id, document_id, None)

View File

@ -6,9 +6,9 @@ identity:
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scrapper tool kit is used to scrape web
en_US: Web Scraper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
pt_BR: Web Scraper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity

View File

@ -47,7 +47,7 @@ class ModelInvocationUtils:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")

View File

@ -2,7 +2,7 @@ import re
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Any
from typing import Any, TypedDict
import httpx
from flask import request
@ -14,6 +14,12 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParamet
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class InterfaceDict(TypedDict):
path: str
method: str
operation: dict[str, Any]
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
@ -35,7 +41,7 @@ class ApiBasedToolSchemaParser:
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
interfaces: list[InterfaceDict] = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:

View File

@ -1,11 +1,11 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.variables.input_entities import VariableEntity
class WorkflowToolConfigurationUtils:

View File

@ -5,7 +5,6 @@ from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode

View File

@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import VariableBase
from core.workflow.variables import VariableBase
class ConversationVariableUpdater(Protocol):

View File

@ -7,12 +7,28 @@ Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Any, final
from contextlib import AbstractContextManager
from typing import Any, Protocol, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
class RedisPipelineProtocol(Protocol):
"""Minimal Redis pipeline contract used by the command channel."""
def lrange(self, name: str, start: int, end: int) -> Any: ...
def delete(self, *names: str) -> Any: ...
def execute(self) -> list[Any]: ...
def rpush(self, name: str, *values: str) -> Any: ...
def expire(self, name: str, time: int) -> Any: ...
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
def get(self, name: str) -> Any: ...
class RedisClientProtocol(Protocol):
"""Redis client contract required by the command channel."""
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
@final
@ -26,7 +42,7 @@ class RedisChannel:
def __init__(
self,
redis_client: "RedisClientWrapper",
redis_client: RedisClientProtocol,
channel_key: str,
command_ttl: int = 3600,
) -> None:

View File

@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import Variable
from core.workflow.variables.variables import Variable
class CommandType(StrEnum):

View File

@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@ -77,13 +76,10 @@ class GraphEngine:
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
@ -163,7 +159,6 @@ class GraphEngine:
layers=self._layers,
execution_context=execution_context,
config=self._config,
stop_event=self._stop_event,
)
# === Orchestration ===
@ -194,7 +189,6 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@ -314,7 +308,6 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@ -348,7 +341,6 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it

View File

@ -3,13 +3,14 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Callers must provide a Redis client dependency from outside the workflow package.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
@ -17,7 +18,6 @@ from core.workflow.graph_engine.entities.commands import (
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -31,8 +31,12 @@ class GraphEngineManager:
by sending commands through Redis channels, without user validation.
"""
@staticmethod
def send_stop_command(task_id: str, reason: str | None = None) -> None:
_redis_client: RedisClientProtocol
def __init__(self, redis_client: RedisClientProtocol) -> None:
self._redis_client = redis_client
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
@ -41,34 +45,31 @@ class GraphEngineManager:
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
GraphEngineManager._send_command(task_id, abort_command)
self._send_command(task_id, abort_command)
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
self._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
self._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
channel = RedisChannel(self._redis_client, channel_key)
try:
channel.send_command(command)

View File

@ -44,7 +44,6 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@ -62,7 +61,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = stop_event
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
@ -70,12 +69,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

View File

@ -42,7 +42,6 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
) -> None:
@ -63,16 +62,13 @@ class Worker(threading.Thread):
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = stop_event
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:

View File

@ -37,7 +37,6 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None,
) -> None:
@ -64,7 +63,6 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@ -135,7 +133,6 @@ class WorkerPool:
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()

View File

@ -34,7 +34,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -53,6 +52,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@ -1,13 +1,13 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment
class AnswerNode(Node[AnswerNodeData]):

View File

@ -305,10 +305,6 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Find all extractor node configurations that have parent_node_id == self._node_id.
@ -338,7 +334,6 @@ class Node(Generic[NodeDataT]):
if not extractor_configs:
return
# Use DifyNodeFactory to properly instantiate nodes with required dependencies
node_factory = DifyNodeFactory(
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
@ -352,23 +347,20 @@ class Node(Generic[NodeDataT]):
try:
nested_node = node_factory.create_node(config)
except ValueError:
# Skip nodes that cannot be created (e.g., unknown type)
continue
# Execute and process nested node events
for event in nested_node.run():
# Tag event with parent node id for stream ordering and history tracking
if isinstance(event, GraphNodeEventBase):
event.in_parent_node_id = self._node_id
if isinstance(event, NodeRunSucceededEvent):
# Store nested node outputs in variable pool
outputs: Mapping[str, Any] = event.node_run_result.outputs
for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
if not isinstance(event, NodeRunStreamChunkEvent):
yield event
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@ -440,21 +432,6 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@ -1,18 +1,15 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import TYPE_CHECKING, Any, ClassVar, cast
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.variables.segments import ArrayFileSegment
from core.workflow.variables.types import SegmentType
from .exc import (
CodeNodeError,
@ -25,12 +22,56 @@ if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
class WorkflowCodeExecutor(Protocol):
def execute(
self,
*,
language: CodeLanguage,
code: str,
inputs: Mapping[str, Any],
) -> Mapping[str, Any]: ...
def is_execution_error(self, error: Exception) -> bool: ...
def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]:
return {
"type": "code",
"config": {
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": language,
"code": code,
"outputs": {"result": {"type": "string", "children": None}},
},
}
_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
CodeLanguage.PYTHON3: dedent(
"""
def main(arg1: str, arg2: str):
return {
"result": arg1 + arg2,
}
"""
),
CodeLanguage.JAVASCRIPT: dedent(
"""
function main({arg1, arg2}) {
return {
result: arg1 + arg2
}
}
"""
),
}
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
@ -40,8 +81,7 @@ class CodeNode(Node[CodeNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_executor: WorkflowCodeExecutor,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
@ -50,10 +90,7 @@ class CodeNode(Node[CodeNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._code_executor: WorkflowCodeExecutor = code_executor
self._limits = code_limits
@classmethod
@ -67,15 +104,10 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language)
if default_code is None:
raise CodeNodeError(f"Unsupported code language: {code_language}")
return _build_default_config(language=code_language, code=default_code)
@classmethod
def version(cls) -> str:
@ -97,8 +129,7 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
_ = self._select_code_provider(code_language)
result = self._code_executor.execute_workflow_code_template(
result = self._code_executor.execute(
language=code_language,
code=code,
inputs=variables,
@ -106,19 +137,19 @@ class CodeNode(Node[CodeNodeData]):
# Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
except CodeNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)
except Exception as e:
if not self._code_executor.is_execution_error(e):
raise
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string

View File

@ -1,11 +1,18 @@
from enum import StrEnum
from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from core.helper.code_executor.code_executor import CodeLanguage
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.variables.types import SegmentType
class CodeLanguage(StrEnum):
PYTHON3 = "python3"
JINJA2 = "jinja2"
JAVASCRIPT = "javascript"
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
[

View File

@ -1,40 +1,26 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.file import File
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
from core.workflow.repositories.datasource_manager_protocol import (
DatasourceManagerProtocol,
DatasourceParameter,
OnlineDriveDownloadFileParam,
)
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
from .exc import DatasourceNodeError
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class DatasourceNode(Node[DatasourceNodeData]):
@ -45,6 +31,22 @@ class DatasourceNode(Node[DatasourceNodeData]):
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.datasource_manager = datasource_manager
def _run(self) -> Generator:
"""
Run the datasource node
@ -52,84 +54,69 @@ class DatasourceNode(Node[DatasourceNodeData]):
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segment:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segement:
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segment:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segement.value
datasource_info_value = datasource_info_segment.value
if not isinstance(datasource_info_value, dict):
raise DatasourceNodeError("Invalid datasource info format")
datasource_info: dict[str, Any] = datasource_info_value
# get datasource runtime
from core.datasource.datasource_manager import DatasourceManager
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=datasource_type,
datasource_type=datasource_type.value,
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
parameters_for_log = datasource_info
try:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=datasource_info.get("credential_id", ""),
)
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
if credentials:
datasource_runtime.runtime.credentials = credentials
online_document_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
),
provider_type=datasource_type,
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
# Build typed request objects
datasource_parameters = None
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_parameters = DatasourceParameter(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
)
)
yield from self._transform_message(
messages=online_document_result,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
)
case DatasourceProviderType.ONLINE_DRIVE:
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
if credentials:
datasource_runtime.runtime.credentials = credentials
online_drive_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.online_drive_download_file(
user_id=self.user_id,
request=OnlineDriveDownloadFileRequest(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket"),
),
provider_type=datasource_type,
online_drive_request = None
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
online_drive_request = OnlineDriveDownloadFileParam(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket", ""),
)
)
yield from self._transform_datasource_file_message(
messages=online_drive_result,
credential_id = datasource_info.get("credential_id", "")
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=self.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
variable_pool=variable_pool,
datasource_type=datasource_type,
datasource_param=datasource_parameters,
online_drive_request=online_drive_request,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield StreamCompletedEvent(
@ -147,23 +134,9 @@ class DatasourceNode(Node[DatasourceNodeData]):
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=self.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
@ -201,55 +174,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
)
)
def _generate_parameters(
self,
*,
datasource_parameters: Sequence[DatasourceParameter],
variable_pool: VariablePool,
node_data: DatasourceNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {}
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
datasource_input = node_data.datasource_parameters[parameter_name]
if datasource_input.type == "variable":
variable = variable_pool.get(datasource_input.value)
if variable is None:
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
elif datasource_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -287,206 +211,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
return result
def _transform_message(
self,
messages: Generator[DatasourceMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
for message in message_stream:
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={**variables},
metadata={
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
},
inputs=parameters_for_log,
)
)
@classmethod
def version(cls) -> str:
return "1"
def _transform_datasource_file_message(
self,
messages: Generator[DatasourceMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: VariablePool,
datasource_type: DatasourceProviderType,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
file = None
for message in message_stream:
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
if file:
variable_pool.add([self._node_id, "file"], file)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file,
"datasource_type": datasource_type,
},
)
)

View File

@ -21,12 +21,12 @@ from docx.table import Table
from docx.text.paragraph import Paragraph
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod, file_manager
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.variables import ArrayFileSegment
from core.workflow.variables.segments import ArrayStringSegment, FileSegment
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

Some files were not shown because too many files have changed in this diff Show More