mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
Merge commit '9c339239' into sandboxed-agent-rebase
Made-with: Cursor # Conflicts: # api/README.md # api/controllers/console/app/workflow_draft_variable.py # api/core/agent/cot_agent_runner.py # api/core/agent/fc_agent_runner.py # api/core/app/apps/advanced_chat/app_runner.py # api/core/plugin/backwards_invocation/model.py # api/core/prompt/advanced_prompt_transform.py # api/core/workflow/nodes/base/node.py # api/core/workflow/nodes/llm/llm_utils.py # api/core/workflow/nodes/llm/node.py # api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py # api/core/workflow/nodes/question_classifier/question_classifier_node.py # api/core/workflow/runtime/graph_runtime_state.py # api/extensions/storage/base_storage.py # api/factories/variable_factory.py # api/pyproject.toml # api/services/variable_truncator.py # api/uv.lock # web/app/account/oauth/authorize/page.tsx # web/app/components/app/configuration/config-var/config-modal/field.tsx # web/app/components/base/alert.tsx # web/app/components/base/chat/chat/answer/human-input-content/executed-action.tsx # web/app/components/base/chat/chat/answer/more.tsx # web/app/components/base/chat/chat/answer/operation.tsx # web/app/components/base/chat/chat/answer/workflow-process.tsx # web/app/components/base/chat/chat/citation/index.tsx # web/app/components/base/chat/chat/citation/popup.tsx # web/app/components/base/chat/chat/citation/progress-tooltip.tsx # web/app/components/base/chat/chat/citation/tooltip.tsx # web/app/components/base/chat/chat/question.tsx # web/app/components/base/chat/embedded-chatbot/inputs-form/index.tsx # web/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx # web/app/components/base/markdown-blocks/form.tsx # web/app/components/base/prompt-editor/plugins/hitl-input-block/component-ui.tsx # web/app/components/base/tag-management/panel.tsx # web/app/components/base/tag-management/trigger.tsx # web/app/components/header/account-setting/index.tsx # 
web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx # web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx # web/app/signin/utils/post-login-redirect.ts # web/eslint-suppressions.json # web/package.json # web/pnpm-lock.yaml
This commit is contained in:
commit
cccff6768a
168
.agents/skills/backend-code-review/SKILL.md
Normal file
168
.agents/skills/backend-code-review/SKILL.md
Normal file
@ -0,0 +1,168 @@
|
||||
---
|
||||
name: backend-code-review
|
||||
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
|
||||
---
|
||||
|
||||
# Backend Code Review
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes:
|
||||
|
||||
- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
|
||||
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
|
||||
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).
|
||||
|
||||
Do NOT use this skill when:
|
||||
|
||||
- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
|
||||
- The user is not asking for a review/analysis/improvement of backend code.
|
||||
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).
|
||||
|
||||
## How to use this skill
|
||||
|
||||
Follow these steps when using this skill:
|
||||
|
||||
1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the user’s input. Keep the scope tight: review only what the user provided or explicitly referenced.
|
||||
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback to perform the best-effort review.
|
||||
3. Compose the final output, strictly following the **Required Output Format**.
|
||||
|
||||
Notes when using this skill:
|
||||
- Always include actionable fixes or suggestions (including possible code snippets).
|
||||
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.
|
||||
|
||||
## Checklist
|
||||
|
||||
- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
|
||||
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
|
||||
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
|
||||
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review
|
||||
|
||||
## General Review Rules
|
||||
|
||||
### 1. Security Review
|
||||
|
||||
Check for:
|
||||
- SQL injection vulnerabilities
|
||||
- Server-Side Request Forgery (SSRF)
|
||||
- Command injection
|
||||
- Insecure deserialization
|
||||
- Hardcoded secrets/credentials
|
||||
- Improper authentication/authorization
|
||||
- Insecure direct object references
|
||||
|
||||
### 2. Performance Review
|
||||
|
||||
Check for:
|
||||
- N+1 queries
|
||||
- Missing database indexes
|
||||
- Memory leaks
|
||||
- Blocking operations in async code
|
||||
- Missing caching opportunities
|
||||
|
||||
### 3. Code Quality Review
|
||||
|
||||
Check for:
|
||||
- Code forward compatibility
|
||||
- Code duplication (DRY violations)
|
||||
- Functions doing too much (SRP violations)
|
||||
- Deep nesting / complex conditionals
|
||||
- Magic numbers/strings
|
||||
- Poor naming
|
||||
- Missing error handling
|
||||
- Incomplete type coverage
|
||||
|
||||
### 4. Testing Review
|
||||
|
||||
Check for:
|
||||
- Missing test coverage for new code
|
||||
- Tests that don't test behavior
|
||||
- Flaky test patterns
|
||||
- Missing edge cases
|
||||
|
||||
## Required Output Format
|
||||
|
||||
When this skill is invoked, the response must exactly follow one of the two templates:
|
||||
|
||||
### Template A (any findings)
|
||||
|
||||
```markdown
|
||||
# Code Review Summary
|
||||
|
||||
Found <X> critical issues that need to be fixed:
|
||||
|
||||
## 🔴 Critical (Must Fix)
|
||||
|
||||
### 1. <brief description of the issue>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<detailed explanation and references of the issue>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
1. <brief description of suggested fix>
|
||||
2. <code example> (optional, omit if not applicable)
|
||||
|
||||
---
|
||||
... (repeat for each critical issue) ...
|
||||
|
||||
Found <Y> suggestions for improvement:
|
||||
|
||||
## 🟡 Suggestions (Should Consider)
|
||||
|
||||
### 1. <brief description of the suggestion>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<detailed explanation and references of the suggestion>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
1. <brief description of suggested fix>
|
||||
2. <code example> (optional, omit if not applicable)
|
||||
|
||||
---
|
||||
... (repeat for each suggestion) ...
|
||||
|
||||
Found <Z> optional nits:
|
||||
|
||||
## 🟢 Nits (Optional)
|
||||
### 1. <brief description of the nit>
|
||||
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
|
||||
#### Explanation
|
||||
|
||||
<explanation and references of the optional nit>
|
||||
|
||||
#### Suggested Fix
|
||||
|
||||
- <minor suggestions>
|
||||
|
||||
---
|
||||
... (repeat for each nit) ...
|
||||
|
||||
## ✅ What's Good
|
||||
|
||||
- <Positive feedback on good patterns>
|
||||
```
|
||||
|
||||
- If there are no critical issues, suggestions, optional nits, or good points, just omit the corresponding section.
|
||||
- If the number of items in a section exceeds 10, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items.
|
||||
- Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
- If any issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants to apply the fix(es). For example: "Would you like me to use the Suggested Fix(es) to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
|
||||
```markdown
|
||||
## Code Review Summary
|
||||
✅ No issues found.
|
||||
```
|
||||
@ -0,0 +1,91 @@
|
||||
# Rule Catalog — Architecture
|
||||
|
||||
## Scope
|
||||
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.
|
||||
|
||||
## Rules
|
||||
|
||||
### Keep business logic out of controllers
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
|
||||
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
@bp.post("/apps/<app_id>/publish")
|
||||
def publish_app(app_id: str):
|
||||
payload = request.get_json() or {}
|
||||
if payload.get("force") and current_user.role != "admin":
|
||||
raise ValueError("only admin can force publish")
|
||||
app = App.query.get(app_id)
|
||||
app.status = "published"
|
||||
db.session.commit()
|
||||
return {"result": "ok"}
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
@bp.post("/apps/<app_id>/publish")
|
||||
def publish_app(app_id: str):
|
||||
payload = PublishRequest.model_validate(request.get_json() or {})
|
||||
app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
|
||||
return {"result": "ok"}
|
||||
```
|
||||
|
||||
### Preserve layer dependency direction
|
||||
- Category: best practices
|
||||
- Severity: critical
|
||||
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
|
||||
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# core/policy/publish_policy.py
|
||||
from controllers.console.app import request_context
|
||||
|
||||
def can_publish() -> bool:
|
||||
return request_context.current_user.is_admin
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# core/policy/publish_policy.py
|
||||
def can_publish(role: str) -> bool:
|
||||
return role == "admin"
|
||||
|
||||
# service layer adapts web/user context to domain input
|
||||
allowed = can_publish(role=current_user.role)
|
||||
```
|
||||
|
||||
### Keep libs business-agnostic
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
|
||||
- Suggested fix:
|
||||
- If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
|
||||
- Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# api/libs/conversation_filter.py
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
def should_archive_conversation(conversation, tenant_id: str) -> bool:
|
||||
# Domain policy and service dependency are leaking into libs.
|
||||
service = ConversationService()
|
||||
if service.has_paid_plan(tenant_id):
|
||||
return conversation.idle_days > 90
|
||||
return conversation.idle_days > 30
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# api/libs/datetime_utils.py (business-agnostic helper)
|
||||
def older_than_days(idle_days: int, threshold_days: int) -> bool:
|
||||
return idle_days > threshold_days
|
||||
|
||||
# services/conversation_service.py (business logic stays in service/core)
|
||||
from libs.datetime_utils import older_than_days
|
||||
|
||||
def should_archive_conversation(conversation, tenant_id: str) -> bool:
|
||||
threshold_days = 90 if has_paid_plan(tenant_id) else 30
|
||||
return older_than_days(conversation.idle_days, threshold_days)
|
||||
```
|
||||
157
.agents/skills/backend-code-review/references/db-schema-rule.md
Normal file
157
.agents/skills/backend-code-review/references/db-schema-rule.md
Normal file
@ -0,0 +1,157 @@
|
||||
# Rule Catalog — DB Schema Design
|
||||
|
||||
## Scope
|
||||
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
|
||||
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Do not query other tables inside `@property`
|
||||
- Category: [maintainability, performance]
|
||||
- Severity: critical
|
||||
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
|
||||
- Suggested fix:
|
||||
- Keep model properties pure and local to already-loaded fields.
|
||||
- Move cross-table data fetching to service/repository methods.
|
||||
- For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
class Conversation(TypeBase):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
@property
|
||||
def app_name(self) -> str:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
|
||||
return app.name
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
class Conversation(TypeBase):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
@property
|
||||
def display_title(self) -> str:
|
||||
return self.name or "Untitled"
|
||||
|
||||
|
||||
# Service/repository layer performs explicit batch fetch for related App rows.
|
||||
```
|
||||
|
||||
### Prefer including `tenant_id` in model definitions
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
|
||||
- Suggested fix:
|
||||
- Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable.
|
||||
- Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
|
||||
- Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class Dataset(TypeBase):
|
||||
__tablename__ = "datasets"
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class Dataset(TypeBase):
|
||||
__tablename__ = "datasets"
|
||||
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
||||
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||
```
|
||||
|
||||
### Detect and avoid duplicate/redundant indexes
|
||||
- Category: performance
|
||||
- Severity: suggestion
|
||||
- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
|
||||
- Suggested fix:
|
||||
- Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
|
||||
- Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
|
||||
- Apply the same review standard in both model `__table_args__` and migration index DDL.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
__table_args__ = (
|
||||
sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
|
||||
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
|
||||
)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
__table_args__ = (
|
||||
# Keep the wider index unless profiling proves a dedicated short index is needed.
|
||||
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
|
||||
)
|
||||
```
|
||||
|
||||
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
|
||||
- Suggested fix:
|
||||
- Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
|
||||
- Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class ToolConfig(TypeBase):
|
||||
__tablename__ = "tool_configs"
|
||||
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from models.types import AdjustedJSON
|
||||
|
||||
class ToolConfig(TypeBase):
|
||||
__tablename__ = "tool_configs"
|
||||
config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
|
||||
```
|
||||
|
||||
### Guard migration incompatibilities with dialect checks and shared types
|
||||
- Category: maintainability
|
||||
- Severity: critical
|
||||
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
|
||||
- Suggested fix:
|
||||
- In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments.
|
||||
- Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
|
||||
- Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"data_source_type",
|
||||
sa.String(255),
|
||||
server_default=sa.text("'database'::character varying"),
|
||||
nullable=False,
|
||||
)
|
||||
)
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
def _is_pg(conn) -> bool:
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
conn = op.get_bind()
|
||||
default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")
|
||||
|
||||
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
|
||||
)
|
||||
```
|
||||
@ -0,0 +1,61 @@
|
||||
# Rule Catalog — Repositories Abstraction
|
||||
|
||||
## Scope
|
||||
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
|
||||
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Introduce repositories abstraction
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
|
||||
- Suggested fix:
|
||||
- First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
|
||||
- If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation).
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# Existing repository is ignored and service uses ad-hoc table queries.
|
||||
class AppService:
|
||||
def archive_app(self, app_id: str, tenant_id: str) -> None:
|
||||
app = self.session.execute(
|
||||
select(App).where(App.id == app_id, App.tenant_id == tenant_id)
|
||||
).scalar_one()
|
||||
app.archived = True
|
||||
self.session.commit()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Case A: Existing repository must be reused for all table operations.
|
||||
class AppService:
|
||||
def archive_app(self, app_id: str, tenant_id: str) -> None:
|
||||
app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
|
||||
app.archived = True
|
||||
self.app_repo.save(app)
|
||||
|
||||
# If the query is missing, extend the existing abstraction.
|
||||
active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
|
||||
```
|
||||
- Bad:
|
||||
```python
|
||||
# No repository exists, but large-domain query logic is scattered in service code.
|
||||
class ConversationService:
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
|
||||
...
|
||||
# many filters/joins/pagination variants duplicated across services
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Case B: Introduce repository for large/complex domains or storage variation.
|
||||
class ConversationRepository(Protocol):
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...
|
||||
|
||||
class SqlAlchemyConversationRepository:
|
||||
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
|
||||
...
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self, conversation_repo: ConversationRepository):
|
||||
self.conversation_repo = conversation_repo
|
||||
```
|
||||
139
.agents/skills/backend-code-review/references/sqlalchemy-rule.md
Normal file
139
.agents/skills/backend-code-review/references/sqlalchemy-rule.md
Normal file
@ -0,0 +1,139 @@
|
||||
# Rule Catalog — SQLAlchemy Patterns
|
||||
|
||||
## Scope
|
||||
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
|
||||
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).
|
||||
|
||||
## Rules
|
||||
|
||||
### Use Session context manager with explicit transaction control behavior
|
||||
- Category: best practices
|
||||
- Severity: critical
|
||||
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
|
||||
- Suggested fix:
|
||||
- Use **explicit `session.commit()`** after completing a related write unit.
|
||||
- Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
|
||||
- Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# Missing commit: write may never be persisted.
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
|
||||
# Long transaction: external I/O inside a DB transaction.
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
call_external_api()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# Option 1: explicit commit.
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
session.commit()
|
||||
|
||||
# Option 2: scoped transaction with automatic commit/rollback.
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
run.status = "cancelled"
|
||||
|
||||
# Keep non-DB work outside transaction scope.
|
||||
call_external_api()
|
||||
```
|
||||
|
||||
### Enforce tenant_id scoping on shared-resource queries
|
||||
- Category: security
|
||||
- Severity: critical
|
||||
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
|
||||
- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
### Prefer SQLAlchemy expressions over raw SQL by default
|
||||
- Category: maintainability
|
||||
- Severity: suggestion
|
||||
- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
|
||||
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
row = session.execute(
|
||||
text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
|
||||
{"id": workflow_id, "tenant_id": tenant_id},
|
||||
).first()
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
row = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
### Protect write paths with concurrency safeguards
|
||||
- Category: quality
|
||||
- Severity: critical
|
||||
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
|
||||
- Suggested fix:
|
||||
- **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or updated_at) guard in `WHERE` and treat `rowcount == 0` as a conflict.
|
||||
- **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
|
||||
- **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
|
||||
- In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
|
||||
- Example:
|
||||
- Bad:
|
||||
```python
|
||||
# No tenant scope, no conflict detection, and no lock on a contested write path.
|
||||
session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
|
||||
session.commit() # silently overwrites concurrent updates
|
||||
```
|
||||
- Good:
|
||||
```python
|
||||
# 1) Optimistic lock (low contention, retry on conflict)
|
||||
result = session.execute(
|
||||
update(WorkflowRun)
|
||||
.where(
|
||||
WorkflowRun.id == run_id,
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.version == expected_version,
|
||||
)
|
||||
.values(status="cancelled", version=WorkflowRun.version + 1)
|
||||
)
|
||||
if result.rowcount == 0:
|
||||
raise WorkflowStateConflictError("stale version, retry")
|
||||
|
||||
# 2) Redis distributed lock (cross-worker critical section)
|
||||
lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
session.execute(
|
||||
update(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
|
||||
.values(status="cancelled")
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
|
||||
run = session.execute(
|
||||
select(WorkflowRun)
|
||||
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
run.status = "cancelled"
|
||||
session.commit()
|
||||
```
|
||||
1
.claude/skills/backend-code-review
Symbolic link
1
.claude/skills/backend-code-review
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/backend-code-review
|
||||
88
.github/workflows/pyrefly-diff-comment.yml
vendored
Normal file
88
.github/workflows/pyrefly-diff-comment.yml
vendored
Normal file
@ -0,0 +1,88 @@
|
||||
name: Comment with Pyrefly Diff
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- Pyrefly Diff Check
|
||||
types:
|
||||
- completed
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with pyrefly diff
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const artifacts = await github.rest.actions.listWorkflowRunArtifacts({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
run_id: ${{ github.event.workflow_run.id }},
|
||||
});
|
||||
const match = artifacts.data.artifacts.find((artifact) =>
|
||||
artifact.name === 'pyrefly_diff'
|
||||
);
|
||||
if (!match) {
|
||||
throw new Error('pyrefly_diff artifact not found');
|
||||
}
|
||||
const download = await github.rest.actions.downloadArtifact({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
artifact_id: match.id,
|
||||
archive_format: 'zip',
|
||||
});
|
||||
fs.writeFileSync('pyrefly_diff.zip', Buffer.from(download.data));
|
||||
|
||||
- name: Unzip artifact
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
|
||||
let prNumber = null;
|
||||
try {
|
||||
prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10);
|
||||
} catch (err) {
|
||||
// Fallback to workflow_run payload if artifact is missing or incomplete.
|
||||
const prs = context.payload.workflow_run.pull_requests || [];
|
||||
if (prs.length > 0 && prs[0].number) {
|
||||
prNumber = prs[0].number;
|
||||
}
|
||||
}
|
||||
if (!prNumber) {
|
||||
throw new Error('PR number not found in artifact or workflow_run payload');
|
||||
}
|
||||
|
||||
const MAX_CHARS = 65000;
|
||||
if (diff.length > MAX_CHARS) {
|
||||
diff = diff.slice(0, MAX_CHARS);
|
||||
diff = diff.slice(0, diff.lastIndexOf('\n'));
|
||||
diff += '\n\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
100
.github/workflows/pyrefly-diff.yml
vendored
Normal file
100
.github/workflows/pyrefly-diff.yml
vendored
Normal file
@ -0,0 +1,100 @@
|
||||
name: Pyrefly Diff Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'api/**/*.py'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pyrefly-diff:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Prepare diagnostics extractor
|
||||
run: |
|
||||
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
|
||||
|
||||
- name: Run pyrefly on PR branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
|
||||
|
||||
- name: Checkout base branch
|
||||
run: git checkout ${{ github.base_ref }}
|
||||
|
||||
- name: Run pyrefly on base branch
|
||||
run: |
|
||||
uv run --directory api --dev pyrefly check 2>&1 \
|
||||
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
|
||||
|
||||
- name: Compute diff
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
pyrefly_diff.txt
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
let diff = fs.readFileSync('pyrefly_diff.txt', { encoding: 'utf8' });
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
|
||||
const MAX_CHARS = 65000;
|
||||
if (diff.length > MAX_CHARS) {
|
||||
diff = diff.slice(0, MAX_CHARS);
|
||||
diff = diff.slice(0, diff.lastIndexOf('\n'));
|
||||
diff += '\n\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? [
|
||||
'### Pyrefly Diff',
|
||||
'<details>',
|
||||
'<summary>base → PR</summary>',
|
||||
'',
|
||||
'```diff',
|
||||
diff,
|
||||
'```',
|
||||
'</details>',
|
||||
].join('\n')
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
63
.github/workflows/web-tests.yml
vendored
63
.github/workflows/web-tests.yml
vendored
@ -3,14 +3,22 @@ name: Web Tests
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@ -39,7 +47,58 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test:ci
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
include-hidden-files: true
|
||||
retention-days: 1
|
||||
|
||||
merge-reports:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
5
Makefile
5
Makefile
@ -68,10 +68,9 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
|
||||
@echo "📝 Running type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@cd api && uv run ty check
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
test:
|
||||
@ -132,7 +131,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@ -1,9 +1,5 @@
|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
|
||||
|
||||
@ -29,6 +29,8 @@ ignore_imports =
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
@ -50,14 +52,9 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
@ -91,7 +88,6 @@ forbidden_modules =
|
||||
core.logging
|
||||
core.mcp
|
||||
core.memory
|
||||
core.model_manager
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
@ -105,29 +101,16 @@ forbidden_modules =
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.entities -> configs
|
||||
core.workflow.nodes.http_request.executor -> configs
|
||||
core.workflow.nodes.http_request.node -> configs
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
@ -140,36 +123,19 @@ ignore_imports =
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.start.entities -> core.app.app_config.entities
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.variables.segments
|
||||
core.workflow.nodes.loop.entities -> core.variables.types
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.tool.tool_node -> models
|
||||
core.workflow.nodes.agent.agent_node -> models.model
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
|
||||
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
|
||||
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.node -> core.helper.code_executor
|
||||
core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
|
||||
@ -178,7 +144,6 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
@ -187,71 +152,13 @@ ignore_imports =
|
||||
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.errors
|
||||
core.workflow.conversation_variable_updater -> core.variables
|
||||
core.workflow.graph_engine.entities.commands -> core.variables.variables
|
||||
core.workflow.nodes.agent.agent_node -> core.variables.segments
|
||||
core.workflow.nodes.answer.answer_node -> core.variables
|
||||
core.workflow.nodes.code.code_node -> core.variables.segments
|
||||
core.workflow.nodes.code.code_node -> core.variables.types
|
||||
core.workflow.nodes.code.entities -> core.variables.types
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
|
||||
core.workflow.nodes.document_extractor.node -> core.variables
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.human_input.entities -> core.variables.consts
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
|
||||
core.workflow.nodes.list_operator.node -> core.variables
|
||||
core.workflow.nodes.list_operator.node -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.variables
|
||||
core.workflow.nodes.loop.loop_node -> core.variables
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.segments
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.variables
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.types
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
|
||||
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
|
||||
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
|
||||
core.workflow.nodes.variable_assigner.v1.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
|
||||
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
|
||||
core.workflow.runtime.read_only_wrappers -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables
|
||||
core.workflow.runtime.variable_pool -> core.variables.consts
|
||||
core.workflow.runtime.variable_pool -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables.variables
|
||||
core.workflow.utils.condition.processor -> core.variables
|
||||
core.workflow.utils.condition.processor -> core.variables.segments
|
||||
core.workflow.variable_loader -> core.variables
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
@ -259,7 +166,7 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
|
||||
@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Optional: start the worker service (async tasks, runs from `api`).
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
|
||||
```bash
|
||||
./dev/start-worker
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -765,7 +765,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -15,11 +15,11 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
from core.workflow.variables.segment_group import SegmentGroup
|
||||
from core.workflow.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
|
||||
from core.workflow.variables.types import SegmentType
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
@ -133,11 +133,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
full_content=fields.Raw(attribute=_serialize_full_content),
|
||||
)
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = {
|
||||
**_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
"value": fields.Raw(attribute=_serialize_var_value),
|
||||
"full_content": fields.Raw(attribute=_serialize_full_content),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
|
||||
@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.variables.types import SegmentType
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
|
||||
@ -44,6 +44,7 @@ from core.errors.error import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.app_fields import (
|
||||
app_detail_fields_with_site,
|
||||
deleted_tool_fields,
|
||||
@ -225,7 +226,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
@ -100,6 +101,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -36,9 +36,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||
@ -214,9 +214,9 @@ def cloud_utm_record(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view: Callable[P, R]):
|
||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
if (
|
||||
dify_config.EDITION == "SELF_HOSTED"
|
||||
|
||||
@ -137,7 +137,7 @@ class FilePreviewApi(Resource):
|
||||
if args.as_attachment:
|
||||
encoded_filename = quote(upload_file.name)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
response.headers["Content-Type"] = "application/octet-stream"
|
||||
|
||||
enforce_download_for_html(
|
||||
response,
|
||||
|
||||
@ -64,6 +64,10 @@ class ToolFileApi(Resource):
|
||||
|
||||
if not stream or not tool_file:
|
||||
raise NotFound("file is not found")
|
||||
|
||||
except NotFound:
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
|
||||
@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import mcp_ns
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types as mcp_types
|
||||
from core.mcp.server.streamable_http import handle_mcp_request
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
|
||||
@ -31,6 +31,7 @@ from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
@ -280,7 +281,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
@ -17,7 +18,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
@ -65,6 +66,12 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Get query parameter to determine published or draft
|
||||
is_published: bool = request.args.get("is_published", default=True, type=bool)
|
||||
|
||||
@ -104,6 +111,12 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
@ -161,6 +174,12 @@ class PipelineRunApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
|
||||
@ -24,6 +24,7 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -121,6 +122,6 @@ class WorkflowTaskStopApi(WebApiResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
# check if model supports stream tool call
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
|
||||
@ -2,12 +2,12 @@ from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
from jsonschema import Draft7Validator, SchemaError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.workflow.file import FileUploadConfig
|
||||
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
|
||||
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
|
||||
|
||||
|
||||
class VariableEntityType(StrEnum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
CHECKBOX = "checkbox"
|
||||
JSON_OBJECT = "json_object"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
|
||||
# `variable` records the name of the variable in user inputs.
|
||||
variable: str
|
||||
label: str
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
hide: bool = False
|
||||
default: Any = None
|
||||
max_length: int | None = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
class RagPipelineVariableEntity(WorkflowVariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
@ -314,7 +260,7 @@ class AppConfig(BaseModel):
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: AppAdditionalFeatures | None = None
|
||||
variables: list[VariableEntity] = []
|
||||
variables: list[WorkflowVariableEntity] = []
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
|
||||
@ -26,7 +26,6 @@ from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.sandbox import Sandbox
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -35,6 +34,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.variables.variables import Variable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -850,16 +850,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle retriever resources events."""
|
||||
self._message_cycle_manager.handle_retriever_resources(event)
|
||||
return
|
||||
yield # Make this a generator
|
||||
yield from ()
|
||||
|
||||
def _handle_annotation_reply_event(
|
||||
self, event: QueueAnnotationReplyEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle annotation reply events."""
|
||||
self._message_cycle_manager.handle_annotation_reply(event)
|
||||
return
|
||||
yield # Make this a generator
|
||||
yield from ()
|
||||
|
||||
def _handle_message_replace_event(
|
||||
self, event: QueueMessageReplaceEvent, **kwargs
|
||||
|
||||
@ -175,7 +175,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
# change function call strategy based on LLM model
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file import File, FileUploadConfig
|
||||
@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
NoopDraftVariableSaver,
|
||||
)
|
||||
from core.workflow.variables.input_entities import VariableEntityType
|
||||
from factories import file_factory
|
||||
from libs.orjson import orjson_dumps
|
||||
from models import Account, EndUser
|
||||
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
|
||||
@ -122,7 +122,7 @@ class AppQueueManager(ABC):
|
||||
"""Attach the live graph runtime state reference for downstream consumers."""
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
|
||||
@ -49,7 +49,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import (
|
||||
@ -62,6 +61,7 @@ from core.workflow.enums import (
|
||||
from core.workflow.file import FILE_MODEL_IDENTITY, File
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -11,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
@ -21,6 +20,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import VariableLoader
|
||||
from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document, Pipeline
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
|
||||
from core.variables import VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.variables import VariableBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
5
api/core/app/llm/__init__.py
Normal file
5
api/core/app/llm/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
110
api/core/app/llm/model_access.py
Normal file
110
api/core/app/llm/model_access.py
Normal file
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
    """Resolve LLM credentials for one tenant through a ``ProviderManager``."""

    # Tenant whose provider configurations are consulted.
    tenant_id: str
    # Manager used to load the tenant's provider configurations.
    provider_manager: ProviderManager

    def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
        """Store the tenant id and use the given manager, or a fresh one if none is supplied."""
        self.tenant_id = tenant_id
        self.provider_manager = provider_manager if provider_manager else ProviderManager()

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        """
        Return the current credentials for an LLM model of the given provider.

        :param provider_name: key used to look up the provider configuration
        :param model_name: name of the LLM model
        :raises ValueError: when the tenant has no configuration for the provider
        :raises ModelNotExistError: when the provider does not expose the model
        :raises ProviderTokenNotInitError: when no credentials are available
        """
        configuration = self.provider_manager.get_configurations(self.tenant_id).get(provider_name)
        if not configuration:
            raise ValueError(f"Provider {provider_name} does not exist.")

        model_entry = configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
        if model_entry is None:
            raise ModelNotExistError(f"Model {model_name} not exist.")
        # Let the provider model report an unusable status before credentials are read.
        model_entry.raise_for_status()

        current_credentials = configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
        if current_credentials is None:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

        return current_credentials
|
||||
|
||||
|
||||
class DifyModelFactory:
    """Create LLM ``ModelInstance`` objects for one tenant through a ``ModelManager``."""

    # Tenant on whose behalf model instances are created.
    tenant_id: str
    # Manager that actually builds the model instances.
    model_manager: ModelManager

    def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
        """Store the tenant id and use the given manager, or a fresh one if none is supplied."""
        self.tenant_id = tenant_id
        self.model_manager = model_manager if model_manager else ModelManager()

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        """Return the manager's LLM model instance for ``provider_name`` / ``model_name``."""
        instance = self.model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=provider_name,
            model_type=ModelType.LLM,
            model=model_name,
        )
        return instance
|
||||
|
||||
|
||||
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
    """
    Build the default credentials provider and model factory for a tenant.

    :param tenant_id: tenant both helpers are bound to
    :return: ``(credentials_provider, model_factory)``, each using its default manager
    """
    credentials_provider = DifyCredentialsProvider(tenant_id=tenant_id)
    model_factory = DifyModelFactory(tenant_id=tenant_id)
    return credentials_provider, model_factory
|
||||
|
||||
|
||||
def fetch_model_config(
    *,
    node_data_model: ModelConfig,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    """
    Resolve a node's model settings into a configured model instance and config entity.

    :param node_data_model: model section of the node data (provider, name, mode, completion params)
    :param credentials_provider: source of the model credentials
    :param model_factory: factory that creates the model instance
    :return: the model instance, updated in place, and the matching config entity
    :raises LLMModeRequiredError: when the node data carries no mode
    :raises ModelNotExistError: when the provider model or its schema cannot be found
    """
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    provider_name = node_data_model.provider
    model_name = node_data_model.name

    credentials = credentials_provider.fetch(provider_name, model_name)
    model_instance = model_factory.init_model_instance(provider_name, model_name)
    bundle = model_instance.provider_model_bundle

    provider_model = bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {model_name} not exist.")
    provider_model.raise_for_status()

    # Work on a copy so popping "stop" does not mutate the node data.
    parameters = dict(node_data_model.completion_params)
    raw_stop = parameters.pop("stop", [])
    # Any non-list "stop" value is discarded.
    stop_words = raw_stop if isinstance(raw_stop, list) else []

    schema = model_instance.model_type_instance.get_model_schema(model_name, credentials)
    if not schema:
        raise ModelNotExistError(f"Model {model_name} not exist.")

    # Write the resolved settings onto the instance itself.
    model_instance.provider = provider_name
    model_instance.model_name = model_name
    model_instance.credentials = credentials
    model_instance.parameters = parameters
    model_instance.stop = tuple(stop_words)

    config_entity = ModelConfigWithCredentialsEntity(
        provider=provider_name,
        model=model_name,
        model_schema=schema,
        mode=node_data_model.mode,
        provider_model_bundle=bundle,
        credentials=credentials,
        parameters=parameters,
        stop=stop_words,
    )
    return model_instance, config_entity
|
||||
93
api/core/app/llm/quota.py
Normal file
93
api/core/app/llm/quota.py
Normal file
@ -0,0 +1,93 @@
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
    """
    Raise if the system-provided model behind ``model_instance`` is out of quota.

    Does nothing when the provider configuration is not using the SYSTEM provider type.

    :param model_instance: model instance about to be invoked
    :raises QuotaExceededError: when the provider model status is ``QUOTA_EXCEEDED``
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    provider_model = configuration.get_provider_model(
        model_type=model_instance.model_type_instance.model_type,
        model=model_instance.model_name,
    )
    # A missing provider model is not treated as a quota problem here.
    if not provider_model:
        return
    if provider_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """
    Charge one model invocation against the tenant's system-provider quota.

    Does nothing when the provider configuration is not using the SYSTEM provider
    type, when no quota configuration matches the current quota type, or when the
    matching configuration has a ``quota_limit`` of ``-1``.

    :param tenant_id: tenant to charge
    :param model_instance: model instance that produced the usage
    :param usage: usage of the invocation; ``total_tokens`` is charged for token-based quotas
    """
    configuration = model_instance.provider_model_bundle.configuration
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_configuration = configuration.system_configuration

    # First quota configuration matching the current quota type, if any.
    active_quota = next(
        (
            candidate
            for candidate in system_configuration.quota_configurations
            if candidate.quota_type == system_configuration.current_quota_type
        ),
        None,
    )
    if active_quota is None:
        return

    quota_unit = active_quota.quota_unit
    # NOTE(review): -1 presumably means "unlimited"; the code only shows that nothing is deducted.
    if active_quota.quota_limit == -1:
        return
    if not quota_unit:
        return

    if quota_unit == QuotaUnit.TOKENS:
        used_quota = usage.total_tokens
    elif quota_unit == QuotaUnit.CREDITS:
        used_quota = dify_config.get_model_credits(model_instance.model_name)
    else:
        # Any other unit costs a flat 1 per invocation.
        used_quota = 1

    quota_type = system_configuration.current_quota_type
    if used_quota is None or quota_type is None:
        return

    if quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
        # Imported lazily, only on the credit-pool path.
        from services.credit_pool_service import CreditPoolService

        if quota_type == ProviderQuotaType.PAID:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        else:
            CreditPoolService.check_and_deduct_credits(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
        return

    # Other quota types are tracked on the Provider row itself.
    with Session(db.engine) as session:
        deduction = (
            update(Provider)
            .where(
                Provider.tenant_id == tenant_id,
                # TODO: Use provider name with prefix after the data migration.
                Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == quota_type.value,
                # Only rows that still have quota left are updated.
                Provider.quota_limit > Provider.quota_used,
            )
            .values(
                quota_used=Provider.quota_used + used_quota,
                last_used=naive_utc_now(),
            )
        )
        session.execute(deduction)
        session.commit()
|
||||
@ -168,7 +168,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@ -181,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@ -294,7 +294,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
@ -408,7 +408,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||
|
||||
from .llm_quota import LLMQuotaLayer
|
||||
from .observability import ObservabilityLayer
|
||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
|
||||
__all__ = [
|
||||
"LLMQuotaLayer",
|
||||
"ObservabilityLayer",
|
||||
"PersistenceWorkflowInfo",
|
||||
"WorkflowPersistenceLayer",
|
||||
|
||||
128
api/core/app/workflow/layers/llm_quota.py
Normal file
128
api/core/app/workflow/layers/llm_quota.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
class LLMQuotaLayer(GraphEngineLayer):
    """
    Graph layer that checks LLM quota before, and deducts it after, model-backed nodes.

    Only LLM, parameter-extractor and question-classifier nodes are considered.
    When quota is exceeded the layer sets the run's stop event (if one exists)
    and sends a single abort command per graph run.
    """

    def __init__(self) -> None:
        super().__init__()
        # True once an abort command has been sent for the current run.
        self._abort_sent = False

    @override
    def on_graph_start(self) -> None:
        """Reset the abort flag at the start of a graph run."""
        self._abort_sent = False

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        """Ignore graph events; quota handling uses the node hooks only."""
        _ = event

    @override
    def on_graph_end(self, error: Exception | None) -> None:
        """Nothing to clean up when the graph ends."""
        _ = error

    @override
    def on_node_run_start(self, node: Node) -> None:
        """Abort the run before ``node`` executes if its model is out of quota."""
        if self._abort_sent:
            return

        model_instance = self._extract_model_instance(node)
        if model_instance is None:
            return

        try:
            ensure_llm_quota_available(model_instance=model_instance)
        except QuotaExceededError as exc:
            self._halt_graph(node, exc)
            logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)

    @override
    def on_node_run_end(
        self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
    ) -> None:
        """Deduct quota for ``node`` after a successful run; failed runs are not charged."""
        if error is not None:
            return
        if not isinstance(result_event, NodeRunSucceededEvent):
            return

        model_instance = self._extract_model_instance(node)
        if model_instance is None:
            return

        try:
            deduct_llm_quota(
                tenant_id=node.tenant_id,
                model_instance=model_instance,
                usage=result_event.node_run_result.llm_usage,
            )
        except QuotaExceededError as exc:
            self._halt_graph(node, exc)
            logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
        except Exception:
            # Any other deduction failure is logged and does not fail the node.
            logger.exception("LLM quota deduction failed, node_id=%s", node.id)

    def _halt_graph(self, node: Node, exc: QuotaExceededError) -> None:
        """Set the stop event and send an abort command carrying the error text."""
        self._set_stop_event(node)
        self._send_abort_command(reason=str(exc))

    @staticmethod
    def _set_stop_event(node: Node) -> None:
        """Set ``stop_event`` on the node's graph runtime state when it has one."""
        stop_event = getattr(node.graph_runtime_state, "stop_event", None)
        if stop_event is None:
            return
        stop_event.set()

    def _send_abort_command(self, *, reason: str) -> None:
        """Send one abort command; skipped without a command channel or once already sent."""
        if self._abort_sent or not self.command_channel:
            return

        abort = AbortCommand(
            command_type=CommandType.ABORT,
            reason=reason,
        )
        try:
            self.command_channel.send_command(abort)
        except Exception:
            logger.exception("Failed to send quota abort command")
        else:
            # Mark as sent only after the channel accepted the command.
            self._abort_sent = True

    @staticmethod
    def _extract_model_instance(node: Node) -> ModelInstance | None:
        """Return the node's ``model_instance`` for model-backed node types, else ``None``."""
        try:
            node_type = node.node_type
            if node_type == NodeType.LLM:
                return cast("LLMNode", node).model_instance
            if node_type == NodeType.PARAMETER_EXTRACTOR:
                return cast("ParameterExtractorNode", node).model_instance
            if node_type == NodeType.QUESTION_CLASSIFIER:
                return cast("QuestionClassifierNode", node).model_instance
            return None
        except AttributeError:
            logger.warning(
                "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
                node.id,
            )
            return None
|
||||
@ -1,37 +1,97 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.memory import PromptMessageMemory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.index_processor.index_processor import IndexProcessor
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeType, SystemVariableKey
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
Jinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def fetch_memory(
    *,
    conversation_id: str | None,
    app_id: str,
    node_data_memory: MemoryConfig | None,
    model_instance: ModelInstance,
) -> TokenBufferMemory | None:
    """Build a ``TokenBufferMemory`` for a conversation, if one applies.

    Returns ``None`` when no memory config or no conversation id is given, or
    when no ``Conversation`` row matches both ``app_id`` and
    ``conversation_id``.  Otherwise the conversation is wrapped together with
    ``model_instance`` in a ``TokenBufferMemory``.
    """
    if not (node_data_memory and conversation_id):
        return None

    lookup = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
    # expire_on_commit=False is passed through unchanged from the original call.
    with Session(db.engine, expire_on_commit=False) as session:
        conversation = session.scalar(lookup)

    if conversation:
        return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
    return None
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
    """Code executor for workflow nodes that delegates to ``CodeExecutor``."""

    def execute(self, *, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        """Run ``code`` in ``language`` with ``inputs`` and return the result.

        The call is forwarded verbatim to
        ``CodeExecutor.execute_workflow_code_template``.
        """
        outcome = CodeExecutor.execute_workflow_code_template(language=language, code=code, inputs=inputs)
        return outcome

    def is_execution_error(self, error: Exception) -> bool:
        """Tell whether ``error`` is a ``CodeExecutionError``."""
        return isinstance(error, CodeExecutionError)
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
@ -45,23 +105,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
|
||||
)
|
||||
self._code_limits = code_limits or CodeNodeLimits(
|
||||
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
|
||||
self._code_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
min_number=dify_config.CODE_MIN_NUMBER,
|
||||
@ -71,21 +119,27 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
self._template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = ToolFileManager
|
||||
self._http_request_file_manager = file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
self._document_extractor_unstructured_api_config = (
|
||||
document_extractor_unstructured_api_config
|
||||
or UnstructuredApiConfig(
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
|
||||
)
|
||||
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
|
||||
)
|
||||
self._http_request_config = build_http_request_config(
|
||||
max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
|
||||
max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||
max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE,
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@ -126,7 +180,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
code_providers=self._code_providers,
|
||||
code_limits=self._code_limits,
|
||||
)
|
||||
|
||||
@ -146,11 +199,45 @@ class DifyNodeFactory(NodeFactory):
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
http_client=self._http_request_http_client,
|
||||
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_INDEX:
|
||||
return KnowledgeIndexNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
index_processor=IndexProcessor(),
|
||||
summary_index_service=SummaryIndex(),
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
@ -169,9 +256,95 @@ class DifyNodeFactory(NodeFactory):
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
    """Create and configure the ``ModelInstance`` for a model-backed node.

    The ``"model"`` entry of ``node_data`` is validated as a ``ModelConfig``.
    Credentials and the model instance are obtained from the factory's
    credentials provider and model factory, the provider model is checked,
    and the instance is populated with provider, model name, credentials,
    completion parameters and stop sequences.

    Raises:
        LLMModeRequiredError: the model config has no ``mode``.
        ModelNotExistError: the provider model or its schema is missing.
    """
    model_config = ModelConfig.model_validate(node_data["model"])
    if not model_config.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    provider_name = model_config.provider
    model_name = model_config.name

    credentials = self._llm_credentials_provider.fetch(provider_name, model_name)
    model_instance = self._llm_model_factory.init_model_instance(provider_name, model_name)

    provider_model = model_instance.provider_model_bundle.configuration.get_provider_model(
        model=model_name,
        model_type=ModelType.LLM,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {model_name} not exist.")
    provider_model.raise_for_status()

    # "stop" is carried separately from the remaining completion parameters;
    # anything that is not a list is treated as "no stop sequences".
    completion_params = dict(model_config.completion_params)
    raw_stop = completion_params.pop("stop", [])
    stop_sequences = tuple(raw_stop) if isinstance(raw_stop, list) else ()

    if not model_instance.model_type_instance.get_model_schema(model_name, credentials):
        raise ModelNotExistError(f"Model {model_name} not exist.")

    model_instance.provider = provider_name
    model_instance.model_name = model_name
    model_instance.credentials = credentials
    model_instance.parameters = completion_params
    model_instance.stop = stop_sequences
    # Typing-only narrowing: ``cast`` returns the same object at runtime.
    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    return model_instance
|
||||
|
||||
def _build_memory_for_llm_node(
    self,
    *,
    node_data: Mapping[str, Any],
    model_instance: ModelInstance,
) -> PromptMessageMemory | None:
    """Resolve the conversation memory for a model-backed node.

    Returns ``None`` when ``node_data`` has no ``"memory"`` entry.  Otherwise
    the entry is validated as a ``MemoryConfig``, the conversation id is read
    from the ``sys`` conversation-id variable of the variable pool (only a
    ``StringSegment`` value is accepted), and the lookup is delegated to
    ``fetch_memory``.
    """
    memory_payload = node_data.get("memory")
    if memory_payload is None:
        return None

    memory_config = MemoryConfig.model_validate(memory_payload)
    segment = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if isinstance(segment, StringSegment):
        conversation_id = segment.value
    else:
        conversation_id = None

    return fetch_memory(
        conversation_id=conversation_id,
        app_id=self.graph_init_params.app_id,
        node_data_memory=memory_config,
        model_instance=model_instance,
    )
|
||||
|
||||
@ -1,16 +1,39 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from threading import Lock
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -103,3 +126,238 @@ class DatasourceManager:
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
|
||||
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
    """Return the icon URL of a datasource for ``tenant_id``.

    ``datasource_type`` is converted with ``DatasourceProviderType.value_of``
    and the datasource runtime is resolved via ``get_datasource_runtime``.
    """
    provider_type = DatasourceProviderType.value_of(datasource_type)
    runtime = cls.get_datasource_runtime(
        provider_id=provider_id,
        datasource_name=datasource_name,
        tenant_id=tenant_id,
        datasource_type=provider_type,
    )
    return runtime.get_icon_url(tenant_id)
|
||||
|
||||
@classmethod
def stream_online_results(
    cls,
    *,
    user_id: str,
    datasource_name: str,
    datasource_type: str,
    provider_id: str,
    tenant_id: str,
    provider: str,
    plugin_id: str,
    credential_id: str,
    datasource_param: DatasourceParameter | None = None,
    online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[DatasourceMessage, None, Any]:
    """Stream ``DatasourceMessage`` items from an online datasource plugin.

    Only ``ONLINE_DOCUMENT`` and ``ONLINE_DRIVE`` are supported.  The
    datasource runtime and its credentials are resolved first; when
    credentials are returned they are assigned to the runtime before the
    request argument is validated.  The plugin's messages are re-yielded
    unchanged and the generator finally returns an empty dict.

    Raises:
        ValueError: ``datasource_param`` is missing for ``ONLINE_DOCUMENT``,
            ``online_drive_request`` is missing for ``ONLINE_DRIVE``, or the
            datasource type is neither of the two.
    """
    ds_type = DatasourceProviderType.value_of(datasource_type)
    runtime = cls.get_datasource_runtime(
        provider_id=provider_id,
        datasource_name=datasource_name,
        tenant_id=tenant_id,
        datasource_type=ds_type,
    )

    credentials = DatasourceProviderService().get_datasource_credentials(
        tenant_id=tenant_id,
        provider=provider,
        plugin_id=plugin_id,
        credential_id=credential_id,
    )

    inner_gen: Generator[DatasourceMessage, None, None]
    if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
        doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
        if credentials:
            doc_runtime.runtime.credentials = credentials
        if datasource_param is None:
            raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
        page_request = GetOnlineDocumentPageContentRequest(
            workspace_id=datasource_param.workspace_id,
            page_id=datasource_param.page_id,
            type=datasource_param.type,
        )
        inner_gen = doc_runtime.get_online_document_page_content(
            user_id=user_id,
            datasource_parameters=page_request,
            provider_type=ds_type,
        )
    elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
        drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
        if credentials:
            drive_runtime.runtime.credentials = credentials
        if online_drive_request is None:
            raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
        download_request = OnlineDriveDownloadFileRequest(
            id=online_drive_request.id,
            bucket=online_drive_request.bucket,
        )
        inner_gen = drive_runtime.online_drive_download_file(
            user_id=user_id,
            request=download_request,
            provider_type=ds_type,
        )
    else:
        raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")

    # Forward every plugin message to the caller as-is.
    yield from inner_gen
    # There is no structured final payload at this level; the generator's
    # return value is always an empty dict.
    return {}
|
||||
|
||||
@classmethod
def stream_node_events(
    cls,
    *,
    node_id: str,
    user_id: str,
    datasource_name: str,
    datasource_type: str,
    provider_id: str,
    tenant_id: str,
    provider: str,
    plugin_id: str,
    credential_id: str,
    parameters_for_log: dict[str, Any],
    datasource_info: dict[str, Any],
    variable_pool: Any,
    datasource_param: DatasourceParameter | None = None,
    online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
    """Translate datasource plugin messages into workflow node events.

    Messages from ``stream_online_results`` are passed through
    ``DatasourceFileMessageTransformer`` and then handled by type:

    * IMAGE_LINK / BINARY_LINK / IMAGE with a text payload: the text is
      treated as a URL, the matching ``ToolFile`` is loaded and turned into
      the output file.
    * TEXT and LINK: emitted as ``text`` stream chunks.
    * VARIABLE: stored in the output variables; streamed values are
      concatenated and also emitted as chunks.
    * FILE: for ``ONLINE_DRIVE``, a ``File`` found in the message meta
      becomes the output file.

    After the messages are exhausted a final empty ``text`` chunk is emitted,
    the file (if any, ``ONLINE_DRIVE`` only) is added to ``variable_pool``,
    and a single ``StreamCompletedEvent`` with SUCCEEDED status is yielded.

    Raises:
        ValueError: no ``ToolFile`` matches the id derived from a link URL.
    """
    ds_type = DatasourceProviderType.value_of(datasource_type)

    raw_messages = cls.stream_online_results(
        user_id=user_id,
        datasource_name=datasource_name,
        datasource_type=datasource_type,
        provider_id=provider_id,
        tenant_id=tenant_id,
        provider=provider,
        plugin_id=plugin_id,
        credential_id=credential_id,
        datasource_param=datasource_param,
        online_drive_request=online_drive_request,
    )
    transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
        messages=raw_messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
    )

    variables: dict[str, Any] = {}
    file_out: File | None = None

    message_types = DatasourceMessage.MessageType
    link_like_types = {message_types.IMAGE_LINK, message_types.BINARY_LINK, message_types.IMAGE}
    is_file_backed_ds = ds_type in {
        DatasourceProviderType.ONLINE_DRIVE,
        DatasourceProviderType.ONLINE_DOCUMENT,
    }

    for message in transformed:
        kind = message.type

        if kind in link_like_types:
            if not (is_file_backed_ds and isinstance(message.message, DatasourceMessage.TextMessage)):
                continue
            url = message.message.text
            # The file id is the last URL path segment, up to its first dot.
            datasource_file_id = str(url).rsplit("/", 1)[-1].partition(".")[0]
            with session_factory.create_session() as session:
                datasource_file = session.scalar(
                    select(ToolFile).where(ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id)
                )
                if not datasource_file:
                    raise ValueError(f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}")
                mime_type = datasource_file.mimetype
                file_out = file_factory.build_from_mapping(
                    mapping={
                        "tool_file_id": datasource_file_id,
                        "type": file_factory.get_file_type_by_mime_type(mime_type),
                        "transfer_method": FileTransferMethod.TOOL_FILE,
                        "url": url,
                    },
                    tenant_id=tenant_id,
                )
        elif kind == message_types.TEXT:
            assert isinstance(message.message, DatasourceMessage.TextMessage)
            yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
        elif kind == message_types.LINK:
            assert isinstance(message.message, DatasourceMessage.TextMessage)
            yield StreamChunkEvent(
                selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
            )
        elif kind == message_types.VARIABLE:
            assert isinstance(message.message, DatasourceMessage.VariableMessage)
            name = message.message.variable_name
            value = message.message.variable_value
            if not message.message.stream:
                variables[name] = value
                continue
            assert isinstance(value, str), "stream variable_value must be str"
            # Streamed variables arrive in pieces; accumulate the full value
            # while forwarding each piece as a chunk.
            variables[name] = variables.get(name, "") + value
            yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
        elif kind == message_types.FILE:
            if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
                candidate = message.meta.get("file")
                if isinstance(candidate, File):
                    file_out = candidate

    # Close the "text" stream.
    yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)

    if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
        variable_pool.add([node_id, "file"], file_out)

    # ONLINE_DOCUMENT reports the collected variables; everything else
    # reports the output file together with the datasource type.
    if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
        outputs: dict[str, Any] = {**variables}
    else:
        outputs = {
            "file": file_out,
            "datasource_type": ds_type,
        }

    yield StreamCompletedEvent(
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=parameters_for_log,
            metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
            outputs=outputs,
        )
    )
|
||||
|
||||
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
    """Load an ``UploadFile`` row and expose it as a workflow ``File``.

    The row is looked up by ``file_id`` and ``tenant_id``.  The resulting
    ``File`` uses the CUSTOM file type and the LOCAL_FILE transfer method,
    and takes both ``remote_url`` and ``url`` from the row's ``source_url``.

    Raises:
        ValueError: no matching ``UploadFile`` exists.
    """
    with session_factory.create_session() as session:
        query = session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id)
        upload_file = query.first()
        if not upload_file:
            raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

        return File(
            id=upload_file.id,
            filename=upload_file.name,
            # The extension is prefixed with a dot here.
            extension="." + upload_file.extension,
            mime_type=upload_file.mime_type,
            tenant_id=tenant_id,
            type=FileType.CUSTOM,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            remote_url=upload_file.source_url,
            related_id=upload_file.id,
            size=upload_file.size,
            storage_key=upload_file.key,
            url=upload_file.source_url,
        )
|
||||
|
||||
@ -379,4 +379,11 @@ class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the file")
|
||||
bucket: str | None = Field(None, description="The name of the bucket")
|
||||
bucket: str = Field("", description="The name of the bucket")
|
||||
|
||||
@field_validator("bucket", mode="before")
@classmethod
def _coerce_bucket(cls, v) -> str:
    """Normalize the raw ``bucket`` input before validation.

    ``None`` becomes the empty string; any other value is converted with
    ``str``.
    """
    return "" if v is None else str(v)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr
|
||||
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
|
||||
@ -40,12 +40,6 @@ class CodeExecutionResponse(BaseModel):
|
||||
data: Data
|
||||
|
||||
|
||||
class CodeLanguage(StrEnum):
|
||||
PYTHON3 = "python3"
|
||||
JINJA2 = "jinja2"
|
||||
JAVASCRIPT = "javascript"
|
||||
|
||||
|
||||
def _build_code_executor_client() -> httpx.Client:
|
||||
return httpx.Client(
|
||||
verify=CODE_EXECUTION_SSL_VERIFY,
|
||||
|
||||
@ -5,7 +5,7 @@ from base64 import b64encode
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.variables.utils import dumps_with_segments
|
||||
from core.workflow.variables.utils import dumps_with_segments
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
|
||||
@ -4,10 +4,10 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types as mcp_types
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
@ -35,9 +35,12 @@ class ModelInstance:
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model = model
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
# Runtime LLM invocation fields.
|
||||
self.parameters: Mapping[str, Any] = {}
|
||||
self.stop: Sequence[str] = ()
|
||||
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
||||
self.load_balancing_manager = self._get_load_balancing_manager(
|
||||
configuration=provider_model_bundle.configuration,
|
||||
@ -163,7 +166,7 @@ class ModelInstance:
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
@ -191,7 +194,7 @@ class ModelInstance:
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
@ -215,7 +218,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
@ -243,7 +246,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
user=user,
|
||||
@ -264,7 +267,7 @@ class ModelInstance:
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
@ -294,7 +297,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@ -328,7 +331,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@ -352,7 +355,7 @@ class ModelInstance:
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
@ -373,7 +376,7 @@ class ModelInstance:
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
@ -396,7 +399,7 @@ class ModelInstance:
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
@ -469,7 +472,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model, credentials=self.credentials, language=language
|
||||
model=self.model_name, credentials=self.credentials, language=language
|
||||
)
|
||||
|
||||
|
||||
|
||||
3
api/core/model_runtime/memory/__init__.py
Normal file
3
api/core/model_runtime/memory/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
|
||||
|
||||
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]
|
||||
18
api/core/model_runtime/memory/prompt_message_memory.py
Normal file
18
api/core/model_runtime/memory/prompt_message_memory.py
Normal file
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
|
||||
|
||||
|
||||
class PromptMessageMemory(Protocol):
|
||||
"""Port for loading memory as prompt messages."""
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Return historical prompt messages constrained by token/message limits."""
|
||||
...
|
||||
@ -83,19 +83,21 @@ def _merge_tool_call_delta(
|
||||
tool_call.function.arguments += delta.function.arguments
|
||||
|
||||
|
||||
def _build_llm_result_from_first_chunk(
|
||||
def _build_llm_result_from_chunks(
|
||||
model: str,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
chunks: Iterator[LLMResultChunk],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Build a single `LLMResult` from the first returned chunk.
|
||||
Build a single `LLMResult` by accumulating all returned chunks.
|
||||
|
||||
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
|
||||
Some models only support streaming output (e.g. Qwen3 open-source edition)
|
||||
and the plugin side may still implement the response via a chunked stream,
|
||||
so all chunks must be consumed and concatenated into a single ``LLMResult``.
|
||||
|
||||
Note:
|
||||
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
|
||||
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
|
||||
The ``usage`` is taken from the last chunk that carries it, which is the
|
||||
typical convention for streaming responses (the final chunk contains the
|
||||
aggregated token counts).
|
||||
"""
|
||||
content = ""
|
||||
content_list: list[PromptMessageContentUnionTypes] = []
|
||||
@ -104,24 +106,27 @@ def _build_llm_result_from_first_chunk(
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
try:
|
||||
first_chunk = next(chunks, None)
|
||||
if first_chunk is not None:
|
||||
if isinstance(first_chunk.delta.message.content, str):
|
||||
content += first_chunk.delta.message.content
|
||||
elif isinstance(first_chunk.delta.message.content, list):
|
||||
content_list.extend(first_chunk.delta.message.content)
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content += chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
content_list.extend(chunk.delta.message.content)
|
||||
|
||||
if first_chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
|
||||
if chunk.delta.message.tool_calls:
|
||||
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||
|
||||
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
|
||||
system_fingerprint = first_chunk.system_fingerprint
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
if chunk.system_fingerprint:
|
||||
system_fingerprint = chunk.system_fingerprint
|
||||
except Exception:
|
||||
logger.exception("Error while consuming non-stream plugin chunk iterator.")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
for _ in chunks:
|
||||
pass
|
||||
except Exception:
|
||||
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
|
||||
# Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections).
|
||||
close = getattr(chunks, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
|
||||
return LLMResult(
|
||||
model=model,
|
||||
@ -174,7 +179,7 @@ def _normalize_non_stream_plugin_result(
|
||||
) -> LLMResult:
|
||||
if isinstance(result, LLMResult):
|
||||
return result
|
||||
return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result)
|
||||
return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result)
|
||||
|
||||
|
||||
def _increase_tool_call(
|
||||
|
||||
@ -14,6 +14,7 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
DIFY_APP_ID,
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_INPUT_MESSAGE,
|
||||
GEN_AI_OUTPUT_MESSAGE,
|
||||
@ -99,6 +100,16 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
|
||||
raise ValueError(f"Aliyun get project url failed: {str(e)}")
|
||||
|
||||
def _extract_app_id(self, trace_info: BaseTraceInfo) -> str:
|
||||
"""Extract app_id from trace_info, trying metadata first then message_data."""
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if app_id:
|
||||
return str(app_id)
|
||||
message_data = getattr(trace_info, "message_data", None)
|
||||
if message_data is not None:
|
||||
return str(getattr(message_data, "app_id", ""))
|
||||
return ""
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_metadata = TraceMetadata(
|
||||
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
|
||||
@ -143,13 +154,16 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_str,
|
||||
),
|
||||
DIFY_APP_ID: self._extract_app_id(trace_info),
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
@ -441,6 +455,8 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
|
||||
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
|
||||
|
||||
app_id = self._extract_app_id(trace_info)
|
||||
|
||||
if message_span_id:
|
||||
message_span = SpanData(
|
||||
trace_id=trace_metadata.trace_id,
|
||||
@ -449,13 +465,16 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
|
||||
outputs=outputs_json,
|
||||
),
|
||||
DIFY_APP_ID: app_id,
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER,
|
||||
@ -469,13 +488,16 @@ class AliyunDataTrace(BaseTraceInstance):
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes=create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
attributes={
|
||||
**create_common_span_attributes(
|
||||
session_id=trace_metadata.session_id,
|
||||
user_id=trace_metadata.user_id,
|
||||
span_kind=GenAISpanKind.CHAIN,
|
||||
inputs=inputs_json,
|
||||
outputs=outputs_json,
|
||||
),
|
||||
**({DIFY_APP_ID: app_id} if message_span_id is None else {}),
|
||||
},
|
||||
status=status,
|
||||
links=trace_metadata.links,
|
||||
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
|
||||
|
||||
@ -3,6 +3,9 @@ from typing import Final
|
||||
|
||||
ACS_ARMS_SERVICE_FEATURE: Final[str] = "acs.arms.service.feature"
|
||||
|
||||
# Dify-specific attributes
|
||||
DIFY_APP_ID: Final[str] = "dify.app_id"
|
||||
|
||||
# Public attributes
|
||||
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
|
||||
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
|
||||
|
||||
@ -155,6 +155,26 @@ def wrap_span_metadata(metadata, **kwargs):
|
||||
return metadata
|
||||
|
||||
|
||||
# Mapping from NodeType string values to OpenInference span kinds.
|
||||
# NodeType values not listed here default to CHAIN.
|
||||
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
"llm": OpenInferenceSpanKindValues.LLM,
|
||||
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
|
||||
"tool": OpenInferenceSpanKindValues.TOOL,
|
||||
"agent": OpenInferenceSpanKindValues.AGENT,
|
||||
}
|
||||
|
||||
|
||||
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
"""Return the OpenInference span kind for a given workflow node type.
|
||||
|
||||
Covers every ``NodeType`` enum value. Nodes that do not have a
|
||||
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
|
||||
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
|
||||
"""
|
||||
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
@ -289,9 +309,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
)
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
span_kind = _get_node_span_kind(node_execution.node_type)
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
@ -306,12 +325,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
node_span = self.tracer.start_span(
|
||||
|
||||
@ -14,10 +14,9 @@ class BaseTraceInstance(ABC):
|
||||
Base trace instance for ops trace services
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, trace_config: BaseTracingConfig):
|
||||
"""
|
||||
Abstract initializer for the trace instance.
|
||||
Initializer for the trace instance.
|
||||
Distribute trace tasks by matching entities
|
||||
"""
|
||||
self.trace_config = trace_config
|
||||
|
||||
@ -129,11 +129,11 @@ class LangfuseSpan(BaseModel):
|
||||
default=None,
|
||||
description="The id of the user that triggered the execution. Used to provide user-level analytics.",
|
||||
)
|
||||
start_time: datetime | str | None = Field(
|
||||
start_time: datetime | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the span started, defaults to the current time.",
|
||||
)
|
||||
end_time: datetime | str | None = Field(
|
||||
end_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the span ended. Automatically set by span.end().",
|
||||
)
|
||||
@ -146,7 +146,7 @@ class LangfuseSpan(BaseModel):
|
||||
description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
|
||||
"via the API.",
|
||||
)
|
||||
level: str | None = Field(
|
||||
level: LevelEnum | None = Field(
|
||||
default=None,
|
||||
description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
|
||||
"traces with elevated error levels and for highlighting in the UI.",
|
||||
@ -222,16 +222,16 @@ class LangfuseGeneration(BaseModel):
|
||||
default=None,
|
||||
description="Identifier of the generation. Useful for sorting/filtering in the UI.",
|
||||
)
|
||||
start_time: datetime | str | None = Field(
|
||||
start_time: datetime | None = Field(
|
||||
default_factory=datetime.now,
|
||||
description="The time at which the generation started, defaults to the current time.",
|
||||
)
|
||||
completion_start_time: datetime | str | None = Field(
|
||||
completion_start_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the completion started (streaming). Set it to get latency analytics broken "
|
||||
"down into time until completion started and completion duration.",
|
||||
)
|
||||
end_time: datetime | str | None = Field(
|
||||
end_time: datetime | None = Field(
|
||||
default=None,
|
||||
description="The time at which the generation ended. Automatically set by generation.end().",
|
||||
)
|
||||
|
||||
@ -41,8 +41,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
def __getitem__(self, key: str) -> dict[str, Any]:
|
||||
match key:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from core.ops.entities.config_entity import LangfuseConfig
|
||||
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
|
||||
@ -149,7 +149,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -9,7 +9,7 @@ from models.engine import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
def filter_none_values(data: dict):
|
||||
def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
|
||||
new_data = {}
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
|
||||
@ -2,6 +2,7 @@ import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
@ -29,7 +30,6 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
@ -120,26 +118,37 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
if isinstance(response, Generator):
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle_non_streaming(response)
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=response.model,
|
||||
prompt_messages=[],
|
||||
system_fingerprint=response.system_fingerprint,
|
||||
structured_output=response.structured_output,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=response.message,
|
||||
usage=response.usage,
|
||||
finish_reason="",
|
||||
),
|
||||
)
|
||||
|
||||
return handle_non_streaming(response)
|
||||
|
||||
@classmethod
|
||||
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import cast
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
prompt_messages = []
|
||||
@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||
@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
@ -111,6 +116,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
parser=parser,
|
||||
prompt_inputs=prompt_inputs,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
if query:
|
||||
@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
if memory and memory_config:
|
||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||
|
||||
prompt_messages = self._append_chat_histories(
|
||||
memory,
|
||||
memory_config,
|
||||
prompt_messages,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
if files and query is not None:
|
||||
for file in files:
|
||||
prompt_message_contents.append(
|
||||
@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
role_prefix: MemoryConfig.RolePrefix,
|
||||
parser: PromptTemplateParser,
|
||||
prompt_inputs: Mapping[str, str],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
) -> Mapping[str, str]:
|
||||
prompt_inputs = dict(prompt_inputs)
|
||||
if "#histories#" in parser.variable_keys:
|
||||
@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
|
||||
|
||||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
||||
rest_tokens = self._calculate_rest_token(
|
||||
[tmp_human_message],
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
histories = self._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
|
||||
@ -41,13 +41,15 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
if not self.memory:
|
||||
return prompt_messages
|
||||
|
||||
max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
|
||||
max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)
|
||||
|
||||
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
self.history_messages,
|
||||
)
|
||||
if curr_message_tokens <= max_token_limit:
|
||||
return self.history_messages
|
||||
@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
# a message is start with UserPromptMessage
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
prompt_messages,
|
||||
)
|
||||
# if current message token is overflow, drop all the prompts in current message and break
|
||||
if curr_message_tokens > max_token_limit:
|
||||
|
||||
@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
|
||||
|
||||
class PromptTransform:
|
||||
def _resolve_model_runtime(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
) -> tuple[ModelInstance, AIModelEntity]:
|
||||
if model_instance is None:
|
||||
if model_config is None:
|
||||
raise ValueError("Either model_config or model_instance must be provided.")
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
model_instance.credentials = model_config.credentials
|
||||
model_instance.parameters = model_config.parameters
|
||||
model_instance.stop = model_config.stop
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
model=model_instance.model_name,
|
||||
credentials=model_instance.credentials,
|
||||
)
|
||||
if model_schema is None:
|
||||
if model_config is None:
|
||||
raise ValueError("Model schema not found for the provided model instance.")
|
||||
model_schema = model_config.model_schema
|
||||
|
||||
return model_instance, model_schema
|
||||
|
||||
def _append_chat_histories(
|
||||
self,
|
||||
memory: BaseMemory,
|
||||
memory_config: MemoryConfig,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
*,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
rest_tokens = self._calculate_rest_token(
|
||||
prompt_messages,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _calculate_rest_token(
|
||||
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
*,
|
||||
model_config: ModelConfigWithCredentialsEntity | None = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
) -> int:
|
||||
model_instance, model_schema = self._resolve_model_runtime(
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
model_parameters = model_instance.parameters
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
for parameter_rule in model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template or "")
|
||||
model_parameters.get(parameter_rule.name)
|
||||
or model_parameters.get(parameter_rule.use_template or "")
|
||||
) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
|
||||
@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
if memory:
|
||||
tmp_human_message = UserPromptMessage(content=prompt)
|
||||
|
||||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
||||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
|
||||
histories = self._get_history_messages_from_memory(
|
||||
memory=memory,
|
||||
memory_config=MemoryConfig(
|
||||
|
||||
@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
|
||||
collection=self._collection_name,
|
||||
metrics=self.config.metrics,
|
||||
include_values=True,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
vector=None,
|
||||
content=None,
|
||||
top_k=1,
|
||||
filter=f"ref_doc_id='{id}'",
|
||||
)
|
||||
@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data=None,
|
||||
collection_data_filter=f"ref_doc_id IN {ids_str}",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
namespace=self.config.namespace,
|
||||
namespace_password=self.config.namespace_password,
|
||||
collection=self._collection_name,
|
||||
collection_data=None, # ty: ignore [invalid-argument-type]
|
||||
collection_data=None,
|
||||
collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
|
||||
)
|
||||
self._client.delete_collection_data(request)
|
||||
@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=query_vector,
|
||||
content=None, # ty: ignore [invalid-argument-type]
|
||||
content=None,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
)
|
||||
@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
|
||||
collection=self._collection_name,
|
||||
include_values=kwargs.pop("include_values", True),
|
||||
metrics=self.config.metrics,
|
||||
vector=None, # ty: ignore [invalid-argument-type]
|
||||
vector=None,
|
||||
content=query,
|
||||
top_k=kwargs.get("top_k", 4),
|
||||
filter=where_clause,
|
||||
|
||||
@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
try:
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # ty: ignore [too-many-positional-arguments]
|
||||
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
|
||||
search_iter = self._scope.search(
|
||||
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -27,14 +27,14 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -46,7 +46,7 @@ class BaseVector(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
|
||||
|
||||
@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings):
|
||||
hash = helper.generate_text_hash(texts[i])
|
||||
if hash not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model,
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings):
|
||||
file_id = multimodel_documents[i]["file_id"]
|
||||
if file_id not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model,
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
|
||||
252
api/core/rag/index_processor/index_processor.py
Normal file
252
api/core/rag/index_processor/index_processor.py
Normal file
@ -0,0 +1,252 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
|
||||
from core.workflow.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .index_processor_factory import IndexProcessorFactory
|
||||
from .processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexProcessor:
    """Facade over the chunk-structure-specific index processors.

    Provides three operations: formatting raw chunks into a ``Preview``,
    (re)indexing a document's chunks while cleaning a previous version, and
    building a preview that is optionally enriched with generated summaries.
    """

    def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
        """Build a ``Preview`` from ``chunks`` for the given ``chunk_structure``.

        The processor selected by ``IndexProcessorFactory`` produces a raw
        mapping; each entry of its ``"preview"`` list is classified as:

        * ``content`` + ``child_chunks`` -> ``PreviewItem`` with children
        * ``question`` + ``answer``      -> ``QaPreview``
        * ``content`` only               -> ``PreviewItem`` without children

        Entries matching none of these shapes are ignored.
        """
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview = index_processor.format_preview(chunks)
        data = Preview(
            chunk_structure=preview["chunk_structure"],
            total_segments=preview["total_segments"],
            preview=[],
            parent_mode=None,
            qa_preview=[],
        )
        if "parent_mode" in preview:
            data.parent_mode = preview["parent_mode"]

        for item in preview["preview"]:
            if "content" in item and "child_chunks" in item:
                data.preview.append(
                    PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
                )
            elif "question" in item and "answer" in item:
                data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
            elif "content" in item:
                data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
        return data

    def index_and_clean(
        self,
        dataset_id: str,
        document_id: str,
        original_document_id: str,
        chunks: Mapping[str, Any],
        batch: Any,
        summary_index_setting: dict | None = None,
    ) -> dict[str, Any]:
        """Index ``chunks`` for a document and remove the index of the document it replaces.

        Args:
            dataset_id: Id of the dataset that owns the document.
            document_id: Id of the document being indexed.
            original_document_id: Id of a previous document whose segments and
                index entries are removed first; falsy to skip the cleanup.
            chunks: Chunk payload handed to the index processor.
            batch: Opaque batch identifier, echoed back in the result.
            summary_index_setting: Summary configuration; when ``None`` the
                dataset's own ``summary_index_setting`` is used.

        Returns:
            A mapping with dataset/document identifiers and names, ``batch``,
            the document creation timestamp and ``display_status``.

        Raises:
            KnowledgeIndexNodeError: If the document or the dataset is not found.
        """
        with session_factory.create_session() as session:
            document = session.query(Document).filter_by(id=document_id).first()
            if not document:
                raise KnowledgeIndexNodeError(f"Document {document_id} not found.")

            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

            # Copy the values needed for the result while the session is open.
            dataset_name_value = dataset.name
            document_name_value = document.name
            created_at_value = document.created_at
            if summary_index_setting is None:
                summary_index_setting = dataset.summary_index_setting
            index_node_ids = []

            index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
            if original_document_id:
                segments = session.scalars(
                    select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
                ).all()
                if segments:
                    index_node_ids = [segment.index_node_id for segment in segments]

        indexing_start_at = time.perf_counter()
        # delete from vector index
        if index_node_ids:
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

        # Remove the replaced document's segment rows in their own transaction.
        with session_factory.create_session() as session, session.begin():
            if index_node_ids:
                segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
                session.execute(segment_delete_stmt)

        index_processor.index(dataset, document, chunks)
        indexing_end_at = time.perf_counter()

        with session_factory.create_session() as session, session.begin():
            document.indexing_latency = indexing_end_at - indexing_start_at
            document.indexing_status = "completed"
            # Timestamps are stored as naive UTC values.
            document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
            document.word_count = (
                session.query(func.sum(DocumentSegment.word_count))
                .where(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                )
                .scalar()
            ) or 0
            # Update need_summary based on dataset's summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable") is True:
                document.need_summary = True
            else:
                document.need_summary = False
            # ``document`` was loaded in an earlier session; add it to this one
            # so the changes above are persisted by this transaction.
            session.add(document)
            # update document segment status
            session.query(DocumentSegment).where(
                DocumentSegment.document_id == document_id,
                DocumentSegment.dataset_id == dataset_id,
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                }
            )

        return {
            "dataset_id": dataset_id,
            "dataset_name": dataset_name_value,
            "batch": batch,
            "document_id": document_id,
            "document_name": document_name_value,
            "created_at": created_at_value.timestamp(),
            "display_status": "completed",
        }

    def get_preview_output(
        self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None
    ) -> Preview:
        """Build a chunk preview and, when enabled, attach a generated summary to each chunk.

        Summaries are generated only when the dataset's indexing technique is
        ``"high_quality"`` and the effective summary setting has a truthy
        ``"enable"``; otherwise the plain formatted preview is returned.

        Args:
            chunks: Chunk payload to preview.
            dataset_id: Id of the dataset providing tenant and index settings.
            document_id: Optional document id; when the document is found its
                ``doc_language`` is forwarded to summary generation.
            chunk_structure: Chunk structure used to pick the index processor.
            summary_index_setting: Summary configuration; when ``None`` the
                dataset's own ``summary_index_setting`` is used.

        Returns:
            The preview, with ``summary`` filled on items that produced one.

        Raises:
            KnowledgeIndexNodeError: If the dataset is not found, or if any
                summary task fails or does not finish within the timeout.
        """
        doc_language = None
        with session_factory.create_session() as session:
            if document_id:
                document = session.query(Document).filter_by(id=document_id).first()
            else:
                document = None

            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset:
                raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.")

            if summary_index_setting is None:
                summary_index_setting = dataset.summary_index_setting

            if document:
                doc_language = document.doc_language
            indexing_technique = dataset.indexing_technique
            tenant_id = dataset.tenant_id

        preview_output = self.format_preview(chunk_structure, chunks)
        if indexing_technique != "high_quality":
            return preview_output

        if not summary_index_setting or not summary_index_setting.get("enable"):
            return preview_output

        # Nothing to summarize (e.g. a QA-only preview). Returning here also
        # avoids ThreadPoolExecutor(max_workers=0), which raises ValueError.
        if not preview_output.preview:
            return preview_output

        chunk_count = len(preview_output.preview)
        logger.info(
            "Generating summaries for %s chunks in preview mode (dataset: %s)",
            chunk_count,
            dataset_id,
        )

        flask_app = None
        try:
            flask_app = current_app._get_current_object()  # type: ignore
        except RuntimeError:
            logger.warning("No Flask application context available, summary generation may fail")

        def _summarize(preview_item: PreviewItem, text: str) -> None:
            """Generate a summary for ``text`` and store it on ``preview_item`` if non-empty."""
            summary, _ = ParagraphIndexProcessor.generate_summary(
                tenant_id=tenant_id,
                text=text,
                summary_index_setting=summary_index_setting,
                document_language=doc_language,
            )
            if summary:
                preview_item.summary = summary

        def generate_summary_for_chunk(preview_item: PreviewItem) -> None:
            """Generate summary for a single chunk."""
            if flask_app:
                # Set Flask application context in worker thread
                with flask_app.app_context():
                    if preview_item.content is not None:
                        _summarize(preview_item, preview_item.content)
            else:
                _summarize(preview_item, preview_item.content if preview_item.content is not None else "")

        # Generate summaries concurrently using ThreadPoolExecutor
        # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
        timeout_seconds = min(300, 60 * chunk_count)
        errors: list[Exception] = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, chunk_count)) as executor:
            futures = [
                executor.submit(generate_summary_for_chunk, preview_item) for preview_item in preview_output.preview
            ]
            # Wait for all tasks to complete with timeout
            done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

            # Cancel tasks that didn't complete in time
            if not_done:
                timeout_error_msg = (
                    f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                )
                logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                # In preview mode, timeout is also an error
                errors.append(TimeoutError(timeout_error_msg))
                # NOTE: cancel() only prevents futures that have not started;
                # tasks already running continue, and leaving this ``with``
                # block still waits for them to finish.
                for future in not_done:
                    future.cancel()
                # Wait a bit for cancellation to take effect
                concurrent.futures.wait(not_done, timeout=5)

            # Collect exceptions from completed futures
            for future in done:
                try:
                    future.result()  # This will raise any exception that occurred
                except Exception as e:
                    logger.exception("Error in summary generation future")
                    errors.append(e)

        # In preview mode, if there are any errors, fail the request
        if errors:
            error_messages = [str(e) for e in errors]
            error_summary = (
                f"Failed to generate summaries for {len(errors)} chunk(s). "
                f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
            )
            if len(errors) > 3:
                error_summary += f" (and {len(errors) - 3} more)"
            logger.error("Summary generation failed in preview mode: %s", error_summary)
            raise KnowledgeIndexNodeError(error_summary)

        completed_count = sum(1 for item in preview_output.preview if item.summary is not None)
        logger.info(
            "Completed summary generation for preview chunks: %s/%s succeeded",
            completed_count,
            chunk_count,
        )
        return preview_output
|
||||
@ -75,15 +75,15 @@ class BaseIndexProcessor(ABC):
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs import helper
|
||||
@ -115,7 +115,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
@ -130,7 +130,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
else:
|
||||
keyword.add_texts(documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
@ -196,7 +196,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
documents: list[Any] = []
|
||||
all_multimodal_documents: list[Any] = []
|
||||
if isinstance(chunks, list):
|
||||
@ -469,12 +469,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
if not isinstance(result, LLMResult):
|
||||
raise ValueError("Expected LLMResult when stream=False")
|
||||
|
||||
summary_content = getattr(result.message, "content", "")
|
||||
summary_content = result.message.get_text_content()
|
||||
usage = result.usage
|
||||
|
||||
# Deduct quota for summary generation (same as workflow nodes)
|
||||
try:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
except Exception as e:
|
||||
# Log but don't fail summary generation if quota deduction fails
|
||||
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
||||
|
||||
@ -126,7 +126,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
for document in documents:
|
||||
@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
@ -272,7 +272,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
child_nodes.append(child_document)
|
||||
return child_nodes
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
parent_childs = ParentChildStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
for parent_child in parent_childs.parent_child_chunks:
|
||||
|
||||
@ -139,14 +139,14 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
multimodal_documents: list[AttachmentDocument] | None = None,
|
||||
with_keywords: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
vector = Vector(dataset)
|
||||
vector.create(documents)
|
||||
if multimodal_documents and dataset.is_multimodal:
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None:
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
@ -206,7 +206,7 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
|
||||
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None:
|
||||
qa_chunks = QAStructureChunk.model_validate(chunks)
|
||||
documents = []
|
||||
for qa_chunk in qa_chunks.qa_chunks:
|
||||
|
||||
@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model=self.rerank_model_instance.model,
|
||||
model=self.rerank_model_instance.model_name,
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if not is_support_vision:
|
||||
|
||||
@ -248,19 +248,22 @@ class DatasetRetrieval:
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
ext_meta = item.metadata or {}
|
||||
title = ext_meta.get("title") or ""
|
||||
doc_id = ext_meta.get("document_id") or title
|
||||
source = Source(
|
||||
metadata=SourceMetadata(
|
||||
source="knowledge",
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id"),
|
||||
document_name=item.metadata.get("title"),
|
||||
dataset_id=ext_meta.get("dataset_id") or "",
|
||||
dataset_name=ext_meta.get("dataset_name") or "",
|
||||
document_id=str(doc_id),
|
||||
document_name=ext_meta.get("title") or "",
|
||||
data_source_type="external",
|
||||
retriever_from="workflow",
|
||||
score=item.metadata.get("score"),
|
||||
doc_metadata=item.metadata,
|
||||
score=float(ext_meta.get("score") or 0.0),
|
||||
doc_metadata=ext_meta,
|
||||
),
|
||||
title=item.metadata.get("title"),
|
||||
title=title,
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
0
api/core/rag/summary_index/__init__.py
Normal file
0
api/core/rag/summary_index/__init__.py
Normal file
86
api/core/rag/summary_index/summary_index.py
Normal file
86
api/core/rag/summary_index/summary_index.py
Normal file
@ -0,0 +1,86 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummaryIndex:
    """Generates and vectorizes segment summaries for a document."""

    def generate_and_vectorize_summary(
        self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None
    ) -> None:
        """Generate summaries inline (preview) or enqueue the background task.

        Outside preview mode the work is handed to
        ``generate_summary_index_task``. In preview mode, summaries are
        generated in worker threads for every completed, enabled segment of
        the document that has no completed summary yet. The call returns
        without doing anything when the dataset is missing or not
        ``"high_quality"``, summaries are not enabled, ``document_id`` is
        empty, the document is missing or is a ``"qa_model"`` document, or no
        segment is pending.
        """
        if not is_preview:
            generate_summary_index_task.delay(dataset_id, document_id, None)
            return

        with session_factory.create_session() as session:
            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
            if not dataset or dataset.indexing_technique != "high_quality":
                return

            # Fall back to the dataset-level configuration when none is supplied.
            if summary_index_setting is None:
                summary_index_setting = dataset.summary_index_setting
            if not (summary_index_setting and summary_index_setting.get("enable")):
                return

            if not document_id:
                return

            document = session.query(Document).filter_by(id=document_id).first()
            # qa_model documents are never summarized.
            if document is None or document.doc_form == "qa_model":
                return

            segment_ids = [
                segment.id
                for segment in session.query(DocumentSegment)
                .filter_by(
                    dataset_id=dataset_id,
                    document_id=document_id,
                    status="completed",
                    enabled=True,
                )
                .all()
            ]
            if not segment_ids:
                return

            already_summarized = {
                summary.chunk_id
                for summary in session.query(DocumentSegmentSummary)
                .filter(
                    DocumentSegmentSummary.chunk_id.in_(segment_ids),
                    DocumentSegmentSummary.dataset_id == dataset_id,
                    DocumentSegmentSummary.status == "completed",
                )
                .all()
            }
            # Only segments that still lack a completed summary are processed.
            pending_segment_ids = [
                candidate_id for candidate_id in segment_ids if candidate_id not in already_summarized
            ]
            if not pending_segment_ids:
                return

            def process_segment(segment_id: str) -> None:
                """Summarize one segment using a session opened in the worker thread."""
                with session_factory.create_session() as worker_session:
                    segment = worker_session.query(DocumentSegment).filter_by(id=segment_id).first()
                    if segment is None:
                        return
                    try:
                        SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
                    except Exception:
                        # Log and carry on so one failure does not stop the other segments.
                        logger.exception(
                            "Failed to generate summary for segment %s",
                            segment_id,
                        )

            worker_count = min(10, len(pending_segment_ids))
            with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
                submitted = [
                    executor.submit(process_segment, pending_id) for pending_id in pending_segment_ids
                ]
                concurrent.futures.wait(submitted)
|
||||
@ -6,9 +6,9 @@ identity:
|
||||
zh_Hans: 网页抓取
|
||||
pt_BR: WebScraper
|
||||
description:
|
||||
en_US: Web Scrapper tool kit is used to scrape web
|
||||
en_US: Web Scraper tool kit is used to scrape web
|
||||
zh_Hans: 一个用于抓取网页的工具。
|
||||
pt_BR: Web Scrapper tool kit is used to scrape web
|
||||
pt_BR: Web Scraper tool kit is used to scrape web
|
||||
icon: icon.svg
|
||||
tags:
|
||||
- productivity
|
||||
|
||||
@ -47,7 +47,7 @@ class ModelInvocationUtils:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
@ -2,7 +2,7 @@ import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
@ -14,6 +14,12 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParamet
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class InterfaceDict(TypedDict):
    """Typed record describing one API interface collected from an OpenAPI document."""

    # Path string of the interface (presumably a key of the OpenAPI "paths" object — confirm against the parser).
    path: str
    # HTTP method name (presumably lowercase, e.g. "get" or "post" — confirm against the parser).
    method: str
    # Operation definition as a plain mapping; its schema is not fixed by this type.
    operation: dict[str, Any]
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
@ -35,7 +41,7 @@ class ApiBasedToolSchemaParser:
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
interfaces: list[InterfaceDict] = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.variables.input_entities import VariableEntity
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
|
||||
@ -5,7 +5,6 @@ from collections.abc import Mapping
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import VariableBase
|
||||
from core.workflow.variables import VariableBase
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
|
||||
@ -7,12 +7,28 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
class RedisPipelineProtocol(Protocol):
|
||||
"""Minimal Redis pipeline contract used by the command channel."""
|
||||
|
||||
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
def execute(self) -> list[Any]: ...
|
||||
def rpush(self, name: str, *values: str) -> Any: ...
|
||||
def expire(self, name: str, time: int) -> Any: ...
|
||||
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
||||
def get(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
|
||||
"""Redis client contract required by the command channel."""
|
||||
|
||||
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
|
||||
|
||||
|
||||
@final
|
||||
@ -26,7 +42,7 @@ class RedisChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: "RedisClientWrapper",
|
||||
redis_client: RedisClientProtocol,
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.variables.variables import Variable
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
|
||||
@ -9,7 +9,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -77,13 +76,10 @@ class GraphEngine:
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
@ -163,7 +159,6 @@ class GraphEngine:
|
||||
layers=self._layers,
|
||||
execution_context=execution_context,
|
||||
config=self._config,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -194,7 +189,6 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -314,7 +308,6 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
@ -348,7 +341,6 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
@ -3,13 +3,14 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Callers must provide a Redis client dependency from outside the workflow package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
@ -17,7 +18,6 @@ from core.workflow.graph_engine.entities.commands import (
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,8 +31,12 @@ class GraphEngineManager:
|
||||
by sending commands through Redis channels, without user validation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
_redis_client: RedisClientProtocol
|
||||
|
||||
def __init__(self, redis_client: RedisClientProtocol) -> None:
|
||||
self._redis_client = redis_client
|
||||
|
||||
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
@ -41,34 +45,31 @@ class GraphEngineManager:
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
self._send_command(task_id, abort_command)
|
||||
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
self._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
self._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
channel = RedisChannel(self._redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
|
||||
@ -44,7 +44,6 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -62,7 +61,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -70,12 +69,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
|
||||
@ -42,7 +42,6 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@ -63,16 +62,13 @@ class Worker(threading.Thread):
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._execution_context = execution_context
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -37,7 +37,6 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
config: GraphEngineConfig,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@ -64,7 +63,6 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +133,6 @@ class WorkerPool:
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
execution_context=self._execution_context,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -34,7 +34,6 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@ -53,6 +52,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayFileSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.variables import ArrayFileSegment, FileSegment, Segment
|
||||
|
||||
|
||||
class AnswerNode(Node[AnswerNodeData]):
|
||||
|
||||
@ -305,10 +305,6 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Find all extractor node configurations that have parent_node_id == self._node_id.
|
||||
@ -338,7 +334,6 @@ class Node(Generic[NodeDataT]):
|
||||
if not extractor_configs:
|
||||
return
|
||||
|
||||
# Use DifyNodeFactory to properly instantiate nodes with required dependencies
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
@ -352,23 +347,20 @@ class Node(Generic[NodeDataT]):
|
||||
try:
|
||||
nested_node = node_factory.create_node(config)
|
||||
except ValueError:
|
||||
# Skip nodes that cannot be created (e.g., unknown type)
|
||||
continue
|
||||
|
||||
# Execute and process nested node events
|
||||
for event in nested_node.run():
|
||||
# Tag event with parent node id for stream ordering and history tracking
|
||||
if isinstance(event, GraphNodeEventBase):
|
||||
event.in_parent_node_id = self._node_id
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
# Store nested node outputs in variable pool
|
||||
outputs: Mapping[str, Any] = event.node_run_result.outputs
|
||||
for variable_name, variable_value in outputs.items():
|
||||
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
|
||||
if not isinstance(event, NodeRunStreamChunkEvent):
|
||||
yield event
|
||||
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -440,21 +432,6 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@ -1,18 +1,15 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.variables.segments import ArrayFileSegment
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
@ -25,12 +22,56 @@ if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class WorkflowCodeExecutor(Protocol):
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
language: CodeLanguage,
|
||||
code: str,
|
||||
inputs: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
def is_execution_error(self, error: Exception) -> bool: ...
|
||||
|
||||
|
||||
def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": "code",
|
||||
"config": {
|
||||
"variables": [
|
||||
{"variable": "arg1", "value_selector": []},
|
||||
{"variable": "arg2", "value_selector": []},
|
||||
],
|
||||
"code_language": language,
|
||||
"code": code,
|
||||
"outputs": {"result": {"type": "string", "children": None}},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = {
|
||||
CodeLanguage.PYTHON3: dedent(
|
||||
"""
|
||||
def main(arg1: str, arg2: str):
|
||||
return {
|
||||
"result": arg1 + arg2,
|
||||
}
|
||||
"""
|
||||
),
|
||||
CodeLanguage.JAVASCRIPT: dedent(
|
||||
"""
|
||||
function main({arg1, arg2}) {
|
||||
return {
|
||||
result: arg1 + arg2
|
||||
}
|
||||
}
|
||||
"""
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
Python3CodeProvider,
|
||||
JavascriptCodeProvider,
|
||||
)
|
||||
_limits: CodeNodeLimits
|
||||
|
||||
def __init__(
|
||||
@ -40,8 +81,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_executor: WorkflowCodeExecutor,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@ -50,10 +90,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
self._code_executor: WorkflowCodeExecutor = code_executor
|
||||
self._limits = code_limits
|
||||
|
||||
@classmethod
|
||||
@ -67,15 +104,10 @@ class CodeNode(Node[CodeNodeData]):
|
||||
if filters:
|
||||
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
|
||||
|
||||
code_provider: type[CodeNodeProvider] = next(
|
||||
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
|
||||
)
|
||||
|
||||
return code_provider.get_default_config()
|
||||
|
||||
@classmethod
|
||||
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
|
||||
return cls._DEFAULT_CODE_PROVIDERS
|
||||
default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language)
|
||||
if default_code is None:
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
return _build_default_config(language=code_language, code=default_code)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
@ -97,8 +129,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
# Run code
|
||||
try:
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
result = self._code_executor.execute(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
@ -106,19 +137,19 @@ class CodeNode(Node[CodeNodeData]):
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
except CodeNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
except Exception as e:
|
||||
if not self._code_executor.is_execution_error(e):
|
||||
raise
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
|
||||
for provider in self._code_providers:
|
||||
if provider.is_accept_language(code_language):
|
||||
return provider
|
||||
raise CodeNodeError(f"Unsupported code language: {code_language}")
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
|
||||
@ -1,11 +1,18 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
|
||||
class CodeLanguage(StrEnum):
|
||||
PYTHON3 = "python3"
|
||||
JINJA2 = "jinja2"
|
||||
JAVASCRIPT = "javascript"
|
||||
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
|
||||
@ -1,40 +1,26 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from core.workflow.repositories.datasource_manager_protocol import (
|
||||
DatasourceManagerProtocol,
|
||||
DatasourceParameter,
|
||||
OnlineDriveDownloadFileParam,
|
||||
)
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
from .exc import DatasourceNodeError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class DatasourceNode(Node[DatasourceNodeData]):
|
||||
@ -45,6 +31,22 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.datasource_manager = datasource_manager
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
@ -52,84 +54,69 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segment:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segement:
|
||||
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
|
||||
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segment:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
datasource_info_value = datasource_info_segment.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_type = DatasourceProviderType.value_of(datasource_type)
|
||||
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_info["icon"] = self.datasource_manager.get_icon_url(
|
||||
provider_id=provider_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=datasource_type,
|
||||
datasource_type=datasource_type.value,
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
|
||||
# Build typed request objects
|
||||
datasource_parameters = None
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_parameters = DatasourceParameter(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
)
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
|
||||
online_drive_request = None
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
online_drive_request = OnlineDriveDownloadFileParam(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket", ""),
|
||||
)
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
|
||||
credential_id = datasource_info.get("credential_id", "")
|
||||
|
||||
yield from self.datasource_manager.stream_node_events(
|
||||
node_id=self._node_id,
|
||||
user_id=self.user_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
datasource_type=datasource_type.value,
|
||||
provider_id=provider_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=credential_id,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_type=datasource_type,
|
||||
datasource_param=datasource_parameters,
|
||||
online_drive_request=online_drive_request,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
||||
@ -147,23 +134,9 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
file_info = self.datasource_manager.get_upload_file_by_id(
|
||||
file_id=related_id, tenant_id=self.tenant_id
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
@ -201,55 +174,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@ -287,206 +211,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={**variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _transform_datasource_file_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
file = None
|
||||
for message in message_stream:
|
||||
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
if file:
|
||||
variable_pool.add([self._node_id, "file"], file)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@ -21,12 +21,12 @@ from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.file import File, FileTransferMethod, file_manager
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.variables import ArrayFileSegment
|
||||
from core.workflow.variables.segments import ArrayStringSegment, FileSegment
|
||||
|
||||
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user