diff --git a/.agents/skills/backend-code-review/SKILL.md b/.agents/skills/backend-code-review/SKILL.md new file mode 100644 index 0000000000..35dc54173e --- /dev/null +++ b/.agents/skills/backend-code-review/SKILL.md @@ -0,0 +1,168 @@ +--- +name: backend-code-review +description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review. +--- + +# Backend Code Review + +## When to use this skill + +Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes: + +- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes). +- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review. +- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`). + +Do NOT use this skill when: + +- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`). +- The user is not asking for a review/analysis/improvement of backend code. +- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`). + +## How to use this skill + +Follow these steps when using this skill: + +1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the user’s input. Keep the scope tight: review only what the user provided or explicitly referenced. +2. Follow the rules defined in **Checklist** to perform the review. 
If no Checklist rule matches, apply **General Review Rules** as a fallback to perform the best-effort review. +3. Compose the final output strictly following the **Required Output Format**. + +Notes when using this skill: +- Always include actionable fixes or suggestions (including possible code snippets). +- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can. + +## Checklist + +- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review +- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review +- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review +- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review + +## General Review Rules + +### 1. Security Review + +Check for: +- SQL injection vulnerabilities +- Server-Side Request Forgery (SSRF) +- Command injection +- Insecure deserialization +- Hardcoded secrets/credentials +- Improper authentication/authorization +- Insecure direct object references + +### 2. Performance Review + +Check for: +- N+1 queries +- Missing database indexes +- Memory leaks +- Blocking operations in async code +- Missing caching opportunities + +### 3. 
Code Quality Review + +Check for: +- Code forward compatibility +- Code duplication (DRY violations) +- Functions doing too much (SRP violations) +- Deep nesting / complex conditionals +- Magic numbers/strings +- Poor naming +- Missing error handling +- Incomplete type coverage + +### 4. Testing Review + +Check for: +- Missing test coverage for new code +- Tests that don't test behavior +- Flaky test patterns +- Missing edge cases + +## Required Output Format + +When this skill is invoked, the response must exactly follow one of the two templates: + +### Template A (any findings) + +```markdown +# Code Review Summary + +Found critical issues that need to be fixed: + +## πŸ”΄ Critical (Must Fix) + +### 1. + +FilePath: line + + +#### Explanation + + + +#### Suggested Fix + +1. +2. (optional, omit if not applicable) + +--- +... (repeat for each critical issue) ... + +Found suggestions for improvement: + +## 🟑 Suggestions (Should Consider) + +### 1. + +FilePath: line + + +#### Explanation + + + +#### Suggested Fix + +1. +2. (optional, omit if not applicable) + +--- +... (repeat for each suggestion) ... + +Found optional nits: + +## 🟒 Nits (Optional) +### 1. + +FilePath: line + + +#### Explanation + + + +#### Suggested Fix + +- + +--- +... (repeat for each nit) ... + +## βœ… What's Good + +- +``` + +- If there are no critical issues or suggestions or optional nits or good points, just omit that section. +- If the issue number is more than 10, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items. +- Don't compress the blank lines between sections; keep them as-is for readability. +- If any issue requires code changes, append a brief follow-up question to ask whether the user wants to apply the fix(es) after the structured output. For example: "Would you like me to use the Suggested fix(es) to address these issues?" + +### Template B (no issues) + +```markdown +## Code Review Summary +βœ… No issues found. 
+``` \ No newline at end of file diff --git a/.agents/skills/backend-code-review/references/architecture-rule.md b/.agents/skills/backend-code-review/references/architecture-rule.md new file mode 100644 index 0000000000..c3fd08bf03 --- /dev/null +++ b/.agents/skills/backend-code-review/references/architecture-rule.md @@ -0,0 +1,91 @@ +# Rule Catalog β€” Architecture + +## Scope +- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow. + +## Rules + +### Keep business logic out of controllers +- Category: maintainability +- Severity: critical +- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test. +- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused. +- Example: + - Bad: + ```python + @bp.post("/apps//publish") + def publish_app(app_id: str): + payload = request.get_json() or {} + if payload.get("force") and current_user.role != "admin": + raise ValueError("only admin can force publish") + app = App.query.get(app_id) + app.status = "published" + db.session.commit() + return {"result": "ok"} + ``` + - Good: + ```python + @bp.post("/apps//publish") + def publish_app(app_id: str): + payload = PublishRequest.model_validate(request.get_json() or {}) + app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id) + return {"result": "ok"} + ``` + +### Preserve layer dependency direction +- Category: best practices +- Severity: critical +- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code. 
+- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse. +- Example: + - Bad: + ```python + # core/policy/publish_policy.py + from controllers.console.app import request_context + + def can_publish() -> bool: + return request_context.current_user.is_admin + ``` + - Good: + ```python + # core/policy/publish_policy.py + def can_publish(role: str) -> bool: + return role == "admin" + + # service layer adapts web/user context to domain input + allowed = can_publish(role=current_user.role) + ``` + +### Keep libs business-agnostic +- Category: maintainability +- Severity: critical +- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions. +- Suggested fix: + - If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers. + - Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`. +- Example: + - Bad: + ```python + # api/libs/conversation_filter.py + from services.conversation_service import ConversationService + + def should_archive_conversation(conversation, tenant_id: str) -> bool: + # Domain policy and service dependency are leaking into libs. 
+ service = ConversationService() + if service.has_paid_plan(tenant_id): + return conversation.idle_days > 90 + return conversation.idle_days > 30 + ``` + - Good: + ```python + # api/libs/datetime_utils.py (business-agnostic helper) + def older_than_days(idle_days: int, threshold_days: int) -> bool: + return idle_days > threshold_days + + # services/conversation_service.py (business logic stays in service/core) + from libs.datetime_utils import older_than_days + + def should_archive_conversation(conversation, tenant_id: str) -> bool: + threshold_days = 90 if has_paid_plan(tenant_id) else 30 + return older_than_days(conversation.idle_days, threshold_days) + ``` \ No newline at end of file diff --git a/.agents/skills/backend-code-review/references/db-schema-rule.md b/.agents/skills/backend-code-review/references/db-schema-rule.md new file mode 100644 index 0000000000..8feae2596a --- /dev/null +++ b/.agents/skills/backend-code-review/references/db-schema-rule.md @@ -0,0 +1,157 @@ +# Rule Catalog β€” DB Schema Design + +## Scope +- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations. +- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`). + +## Rules + +### Do not query other tables inside `@property` +- Category: [maintainability, performance] +- Severity: critical +- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections. +- Suggested fix: + - Keep model properties pure and local to already-loaded fields. + - Move cross-table data fetching to service/repository methods. 
+ - For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values. +- Example: + - Bad: + ```python + class Conversation(TypeBase): + __tablename__ = "conversations" + + @property + def app_name(self) -> str: + with Session(db.engine, expire_on_commit=False) as session: + app = session.execute(select(App).where(App.id == self.app_id)).scalar_one() + return app.name + ``` + - Good: + ```python + class Conversation(TypeBase): + __tablename__ = "conversations" + + @property + def display_title(self) -> str: + return self.name or "Untitled" + + + # Service/repository layer performs explicit batch fetch for related App rows. + ``` + +### Prefer including `tenant_id` in model definitions +- Category: maintainability +- Severity: suggestion +- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows. +- Suggested fix: + - Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable. + - Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware. + - Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly. 
+- Example: + - Bad: + ```python + from sqlalchemy.orm import Mapped + + class Dataset(TypeBase): + __tablename__ = "datasets" + id: Mapped[str] = mapped_column(StringUUID, primary_key=True) + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + ``` + - Good: + ```python + from sqlalchemy.orm import Mapped + + class Dataset(TypeBase): + __tablename__ = "datasets" + id: Mapped[str] = mapped_column(StringUUID, primary_key=True) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + ``` + +### Detect and avoid duplicate/redundant indexes +- Category: performance +- Severity: suggestion +- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans. +- Suggested fix: + - Before adding an index, compare against existing composite indexes by leftmost-prefix rules. + - Drop or avoid creating redundant prefixes unless there is a proven query-pattern need. + - Apply the same review standard in both model `__table_args__` and migration index DDL. +- Example: + - Bad: + ```python + __table_args__ = ( + sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"), + sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"), + ) + ``` + - Good: + ```python + __table_args__ = ( + # Keep the wider index unless profiling proves a dedicated short index is needed. + sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"), + ) + ``` + +### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types` +- Category: maintainability +- Severity: critical +- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. 
When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code. +- Suggested fix: + - Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used. + - Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL. +- Example: + - Bad: + ```python + from sqlalchemy.dialects.postgresql import JSONB + from sqlalchemy.orm import Mapped + + class ToolConfig(TypeBase): + __tablename__ = "tool_configs" + config: Mapped[dict] = mapped_column(JSONB, nullable=False) + ``` + - Good: + ```python + from sqlalchemy.orm import Mapped + + from models.types import AdjustedJSON + + class ToolConfig(TypeBase): + __tablename__ = "tool_configs" + config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False) + ``` + +### Guard migration incompatibilities with dialect checks and shared types +- Category: maintainability +- Severity: critical +- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable. +- Suggested fix: + - In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments. + - Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models. + - Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception. 
+- Example: + - Bad: + ```python + with op.batch_alter_table("dataset_keyword_tables") as batch_op: + batch_op.add_column( + sa.Column( + "data_source_type", + sa.String(255), + server_default=sa.text("'database'::character varying"), + nullable=False, + ) + ) + ``` + - Good: + ```python + def _is_pg(conn) -> bool: + return conn.dialect.name == "postgresql" + + + conn = op.get_bind() + default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'") + + with op.batch_alter_table("dataset_keyword_tables") as batch_op: + batch_op.add_column( + sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False) + ) + ``` diff --git a/.agents/skills/backend-code-review/references/repositories-rule.md b/.agents/skills/backend-code-review/references/repositories-rule.md new file mode 100644 index 0000000000..555de98eb0 --- /dev/null +++ b/.agents/skills/backend-code-review/references/repositories-rule.md @@ -0,0 +1,61 @@ +# Rule Catalog - Repositories Abstraction + +## Scope +- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations. +- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`). + +## Rules + +### Introduce repositories abstraction +- Category: maintainability +- Severity: suggestion +- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation. 
+- Suggested fix: + - First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access. + - If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation). +- Example: + - Bad: + ```python + # Existing repository is ignored and service uses ad-hoc table queries. + class AppService: + def archive_app(self, app_id: str, tenant_id: str) -> None: + app = self.session.execute( + select(App).where(App.id == app_id, App.tenant_id == tenant_id) + ).scalar_one() + app.archived = True + self.session.commit() + ``` + - Good: + ```python + # Case A: Existing repository must be reused for all table operations. + class AppService: + def archive_app(self, app_id: str, tenant_id: str) -> None: + app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id) + app.archived = True + self.app_repo.save(app) + + # If the query is missing, extend the existing abstraction. + active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id) + ``` + - Bad: + ```python + # No repository exists, but large-domain query logic is scattered in service code. + class ConversationService: + def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: + ... + # many filters/joins/pagination variants duplicated across services + ``` + - Good: + ```python + # Case B: Introduce repository for large/complex domains or storage variation. + class ConversationRepository(Protocol): + def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ... 
+ + class SqlAlchemyConversationRepository: + def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: + ... + + class ConversationService: + def __init__(self, conversation_repo: ConversationRepository): + self.conversation_repo = conversation_repo + ``` diff --git a/.agents/skills/backend-code-review/references/sqlalchemy-rule.md b/.agents/skills/backend-code-review/references/sqlalchemy-rule.md new file mode 100644 index 0000000000..cda3a5dc98 --- /dev/null +++ b/.agents/skills/backend-code-review/references/sqlalchemy-rule.md @@ -0,0 +1,139 @@ +# Rule Catalog β€” SQLAlchemy Patterns + +## Scope +- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards. +- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`). + +## Rules + +### Use Session context manager with explicit transaction control behavior +- Category: best practices +- Severity: critical +- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk. +- Suggested fix: + - Use **explicit `session.commit()`** after completing a related write unit. + - Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block. + - Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction. +- Example: + - Bad: + ```python + # Missing commit: write may never be persisted. + with Session(db.engine, expire_on_commit=False) as session: + run = session.get(WorkflowRun, run_id) + run.status = "cancelled" + + # Long transaction: external I/O inside a DB transaction. 
+ with Session(db.engine, expire_on_commit=False) as session, session.begin(): + run = session.get(WorkflowRun, run_id) + run.status = "cancelled" + call_external_api() + ``` + - Good: + ```python + # Option 1: explicit commit. + with Session(db.engine, expire_on_commit=False) as session: + run = session.get(WorkflowRun, run_id) + run.status = "cancelled" + session.commit() + + # Option 2: scoped transaction with automatic commit/rollback. + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + run = session.get(WorkflowRun, run_id) + run.status = "cancelled" + + # Keep non-DB work outside transaction scope. + call_external_api() + ``` + +### Enforce tenant_id scoping on shared-resource queries +- Category: security +- Severity: critical +- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption. +- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces. +- Example: + - Bad: + ```python + stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.execute(stmt).scalar_one_or_none() + ``` + - Good: + ```python + stmt = select(Workflow).where( + Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id, + ) + workflow = session.execute(stmt).scalar_one_or_none() + ``` + +### Prefer SQLAlchemy expressions over raw SQL by default +- Category: maintainability +- Severity: suggestion +- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase. +- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints. 
+- Example: + - Bad: + ```python + row = session.execute( + text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"), + {"id": workflow_id, "tenant_id": tenant_id}, + ).first() + ``` + - Good: + ```python + stmt = select(Workflow).where( + Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id, + ) + row = session.execute(stmt).scalar_one_or_none() + ``` + +### Protect write paths with concurrency safeguards +- Category: quality +- Severity: critical +- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy. +- Suggested fix: + - **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or updated_at) guard in `WHERE` and treat `rowcount == 0` as a conflict. + - **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion. + - **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk. + - In all cases, scope by `tenant_id` and verify affected row counts for conditional writes. +- Example: + - Bad: + ```python + # No tenant scope, no conflict detection, and no lock on a contested write path. 
+ session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled")) + session.commit() # silently overwrites concurrent updates + ``` + - Good: + ```python + # 1) Optimistic lock (low contention, retry on conflict) + result = session.execute( + update(WorkflowRun) + .where( + WorkflowRun.id == run_id, + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.version == expected_version, + ) + .values(status="cancelled", version=WorkflowRun.version + 1) + ) + if result.rowcount == 0: + raise WorkflowStateConflictError("stale version, retry") + + # 2) Redis distributed lock (cross-worker critical section) + lock_name = f"workflow_run_lock:{tenant_id}:{run_id}" + with redis_client.lock(lock_name, timeout=20): + session.execute( + update(WorkflowRun) + .where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id) + .values(status="cancelled") + ) + session.commit() + + # 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention) + run = session.execute( + select(WorkflowRun) + .where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id) + .with_for_update() + ).scalar_one() + run.status = "cancelled" + session.commit() + ``` \ No newline at end of file diff --git a/.claude/skills/backend-code-review b/.claude/skills/backend-code-review new file mode 120000 index 0000000000..fb4ebdf8ee --- /dev/null +++ b/.claude/skills/backend-code-review @@ -0,0 +1 @@ +../../.agents/skills/backend-code-review \ No newline at end of file diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index b21aa17483..f9fbcba465 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -77,14 +77,7 @@ jobs: } const body = diff.trim() - ? `### Pyrefly Diff -
-base β†’ PR - -\`\`\`diff -${diff} -\`\`\` -
` + ? '### Pyrefly Diff\n
\nbase β†’ PR\n\n```diff\n' + diff + '\n```\n
' : '### Pyrefly Diff\nNo changes detected.'; await github.rest.issues.createComment({ diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 0311187d44..2d22231144 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -74,14 +74,16 @@ jobs: } const body = diff.trim() - ? `### Pyrefly Diff -
-base β†’ PR - -\`\`\`diff -${diff} -\`\`\` -
` + ? [ + '### Pyrefly Diff', + '
', + 'base β†’ PR', + '', + '```diff', + diff, + '```', + '
', + ].join('\n') : '### Pyrefly Diff\nNo changes detected.'; await github.rest.issues.createComment({ diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 78d0b2af40..f50689636b 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -3,14 +3,22 @@ name: Web Tests on: workflow_call: +permissions: + contents: read + concurrency: group: web-tests-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: test: - name: Web Tests + name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }}) runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + shardIndex: [1, 2, 3, 4] + shardTotal: [4] defaults: run: shell: bash @@ -39,7 +47,58 @@ jobs: run: pnpm install --frozen-lockfile - name: Run tests - run: pnpm test:ci + run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage + + - name: Upload blob report + if: ${{ !cancelled() }} + uses: actions/upload-artifact@v6 + with: + name: blob-report-${{ matrix.shardIndex }} + path: web/.vitest-reports/* + include-hidden-files: true + retention-days: 1 + + merge-reports: + name: Merge Test Reports + if: ${{ !cancelled() }} + needs: [test] + runs-on: ubuntu-latest + defaults: + run: + shell: bash + working-directory: ./web + + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + persist-credentials: false + + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + package_json_file: web/package.json + run_install: false + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: 24 + cache: pnpm + cache-dependency-path: ./web/pnpm-lock.yaml + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Download blob reports + uses: actions/download-artifact@v6 + with: + path: web/.vitest-reports + pattern: blob-report-* + merge-multiple: true + + - name: Merge reports + run: pnpm vitest --merge-reports --coverage --silent=passed-only - 
name: Coverage Summary if: always() diff --git a/api/.importlinter b/api/.importlinter index c9364a0896..49cf70d61a 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -50,7 +50,6 @@ forbidden_modules = allow_indirect_imports = True ignore_imports = core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.datasource.datasource_node -> extensions.ext_database core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database core.workflow.nodes.llm.file_saver -> extensions.ext_database core.workflow.nodes.llm.llm_utils -> extensions.ext_database @@ -106,15 +105,10 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> core.model_manager core.workflow.nodes.agent.agent_node -> core.provider_manager core.workflow.nodes.agent.agent_node -> core.tools.tool_manager - core.workflow.nodes.datasource.datasource_node -> models.model - core.workflow.nodes.datasource.datasource_node -> models.tools - core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy - core.workflow.nodes.http_request.node -> core.tools.tool_file_manager core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory core.workflow.nodes.llm.llm_utils -> configs - core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities core.workflow.nodes.llm.llm_utils -> core.model_manager core.workflow.nodes.llm.protocols -> core.model_manager core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model @@ -133,36 +127,21 @@ ignore_imports = core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities 
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities - core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform core.workflow.nodes.start.entities -> core.app.app_config.entities core.workflow.nodes.start.start_node -> core.app.app_config.entities core.workflow.workflow_entry -> core.app.apps.exc core.workflow.workflow_entry -> core.app.entities.app_invoke_entities core.workflow.workflow_entry -> core.app.workflow.node_factory - core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager - core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager - core.workflow.nodes.llm.llm_utils -> core.variables.segments - core.workflow.nodes.loop.entities -> core.variables.types core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer core.workflow.nodes.tool.tool_node -> models core.workflow.nodes.agent.agent_node -> models.model - 
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider - core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider - core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider - core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor - core.workflow.nodes.datasource.datasource_node -> core.variables.variables - core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy - core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy core.workflow.nodes.llm.node -> core.helper.code_executor core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor @@ -190,57 +169,7 @@ ignore_imports = core.workflow.nodes.llm.file_saver -> core.tools.signature core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager core.workflow.nodes.tool.tool_node -> core.tools.errors - core.workflow.conversation_variable_updater -> core.variables - core.workflow.graph_engine.entities.commands -> core.variables.variables - core.workflow.nodes.agent.agent_node -> core.variables.segments - core.workflow.nodes.answer.answer_node -> core.variables - core.workflow.nodes.code.code_node -> core.variables.segments - core.workflow.nodes.code.code_node -> core.variables.types - core.workflow.nodes.code.entities -> core.variables.types - core.workflow.nodes.datasource.datasource_node -> core.variables.segments - core.workflow.nodes.document_extractor.node -> core.variables - core.workflow.nodes.document_extractor.node -> core.variables.segments - core.workflow.nodes.http_request.executor -> core.variables.segments - core.workflow.nodes.http_request.node -> core.variables.segments - core.workflow.nodes.human_input.entities -> core.variables.consts - core.workflow.nodes.iteration.iteration_node -> core.variables - core.workflow.nodes.iteration.iteration_node -> 
core.variables.segments - core.workflow.nodes.iteration.iteration_node -> core.variables.variables - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments - core.workflow.nodes.list_operator.node -> core.variables - core.workflow.nodes.list_operator.node -> core.variables.segments - core.workflow.nodes.llm.node -> core.variables - core.workflow.nodes.loop.loop_node -> core.variables - core.workflow.nodes.parameter_extractor.entities -> core.variables.types - core.workflow.nodes.parameter_extractor.exc -> core.variables.types - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types - core.workflow.nodes.tool.tool_node -> core.variables.segments - core.workflow.nodes.tool.tool_node -> core.variables.variables - core.workflow.nodes.trigger_webhook.node -> core.variables.types - core.workflow.nodes.trigger_webhook.node -> core.variables.variables - core.workflow.nodes.variable_aggregator.entities -> core.variables.types - core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments - core.workflow.nodes.variable_assigner.common.helpers -> core.variables - core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts - core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types - core.workflow.nodes.variable_assigner.v1.node -> core.variables - core.workflow.nodes.variable_assigner.v2.helpers -> core.variables - core.workflow.nodes.variable_assigner.v2.node -> core.variables - core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts - core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments - core.workflow.runtime.read_only_wrappers -> core.variables.segments - core.workflow.runtime.variable_pool -> core.variables - core.workflow.runtime.variable_pool -> core.variables.consts - core.workflow.runtime.variable_pool -> 
core.variables.segments - core.workflow.runtime.variable_pool -> core.variables.variables - core.workflow.utils.condition.processor -> core.variables - core.workflow.utils.condition.processor -> core.variables.segments - core.workflow.variable_loader -> core.variables - core.workflow.variable_loader -> core.variables.consts - core.workflow.workflow_type_encoder -> core.variables core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.datasource.datasource_node -> extensions.ext_database core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database core.workflow.nodes.llm.file_saver -> extensions.ext_database core.workflow.nodes.llm.llm_utils -> extensions.ext_database diff --git a/api/README.md b/api/README.md index b23edeab72..b647367046 100644 --- a/api/README.md +++ b/api/README.md @@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a 1. Set up your application by visiting `http://localhost:3000`. -1. Optional: start the worker service (async tasks, runs from `api`). +1. Start the worker service (async and scheduler tasks, runs from `api`). ```bash ./dev/start-worker @@ -54,86 +54,6 @@ The scripts resolve paths relative to their location, so you can run them from a ./dev/start-beat ``` -### Manual commands - -
-Show manual setup and run steps - -These commands assume you start from the repository root. - -1. Start the docker-compose stack. - - The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`. - - ```bash - cp docker/middleware.env.example docker/middleware.env - # Use mysql or another vector database profile if you are not using postgres/weaviate. - docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d - ``` - -1. Copy env files. - - ```bash - cp api/.env.example api/.env - cp web/.env.example web/.env.local - ``` - -1. Install UV if needed. - - ```bash - pip install uv - # Or on macOS - brew install uv - ``` - -1. Install API dependencies. - - ```bash - cd api - uv sync --group dev - ``` - -1. Install web dependencies. - - ```bash - cd web - pnpm install - cd .. - ``` - -1. Start backend (runs migrations first, in a new terminal). - - ```bash - cd api - uv run flask db upgrade - uv run flask run --host 0.0.0.0 --port=5001 --debug - ``` - -1. Start Dify [web](../web) service (in a new terminal). - - ```bash - cd web - pnpm dev:inspect - ``` - -1. Set up your application by visiting `http://localhost:3000`. - -1. Optional: start the worker service (async tasks, in a new terminal). - - ```bash - cd api - uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention - ``` - -1. Optional: start Celery Beat (scheduled tasks, in a new terminal). - - ```bash - cd api - uv run celery -A app.celery beat - ``` - -
- ### Environment notes > [!IMPORTANT] diff --git a/api/constants/pipeline_templates.json b/api/constants/pipeline_templates.json index 32b42769e3..ac63ac39d2 100644 --- a/api/constants/pipeline_templates.json +++ b/api/constants/pipeline_templates.json @@ -50,6 +50,22 @@ "chunk_structure": "qa_model", "language": "en-US" }, + { + "id": "103825d3-7018-43ae-bcf0-f3c001f3eb69", + "name": "Contextual Enrichment Using LLM", + "description": "This knowledge pipeline uses LLMs to extract content from images and tables in documents and automatically generate descriptive annotations for contextual enrichment.", + "icon": { + "icon_type": "image", + "icon": "e642577f-da15-4c03-81b9-c9dec9189a3c", + "icon_background": null, + "icon_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAKAAAACgCAYAAACLz2ctAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAAoKADAAQAAAABAAAAoAAAAACn7BmJAAAP9UlEQVR4Ae2dTXPbxhnHdwFRr5ZN2b1kJraouk57i/IJrJx6jDPT9Fpnkrvj3DOOv0DsXDvJxLk2nUnSW09hPkGc6aWdOBEtpZNLE9Gy3iiSQJ//gg8DQnyFFiAAPjtDLbAA9uWPn5595VKrjLjtn/YqrZaq+L6quL5X9pQqO1qtI3u+0mXy8MFJxfihP1qrss/XQ+FFPtRK1UmreriMJkz/GqaVX8N1z1dPHdyvnZpP1+fmVG3jhTVzDden6SjP6brt7b1y21VbWnk3CawKAbWp9Fmo0s3VbKamffWYgKz5vv+t1s5jt62qGxtrPVAnrUwqAH63u7dF/4E3qaBbVCB8zjjHcZRDJs91XaXJpOGDMDgSx5zj2HWDMByz4/v5fBZ80lLhE3Y498jcsfO8Nt1DlYbvmXs9L/DbbY/uozqmjwOUSvvVtuN8+tKLa4/73GI1KDEAYek8x7vta/0a5XiLcw1Y5uZcAxpgK5VKXeD4HvHTUaDdbivA2Go1yW+rZrPVkzDBUSOk7//u2m8e9VyweGIdQAPenLpD/3LvcLsM0C0szBNs8wY+nIvLpgKA8PS0YWBkKwkQyUo8un517b7tXFsl4cnO/25p33lA7YoKMloqzanFxSXj2864xJe8Ao3GaRdGpAYQbVtEKwCS1au0Xf8TyuMWMirgQYXiOFjFw8PDcLvxC7ek79roSZ8bwO3dvTue77+P6hZV69LSElm9heKoLyXpKgCLeHx8zCBSb9m7e972YWwATVvPVfeoL/YOcjg/X1IrKyvd3mo313JQKAXQLgSEgBGO3v/DG9eu3I1byFgAosr1HP9zauttitWLK32+nzs5aRgQMfSDoRtnXr8ep0qeGMAOfF+ho4FxuosXV7vjdfmWVHI/qQKwhvv7z02VTCDVnJJ+dVIIJwIwDB/G8FZXLwh8k761gt0PCJ8/PzDjiHEgHBvAKHywfDKeVzCaYhYH1TAsIQazJ4VwLAAFvphvZoYeiwvh2YnVPqJ1OhwVVLti+foIJEGmNgQbYISG5Creqf85Ga7yKGlGAvj9zh5mNjbR4UCbT6r
dUZLO7nWwwf0CMNNyvXuj1BhaBdPU2m2lnE8Q8aVLF6XDMUpNuW4UQMfk2bN9swKHqua7N9avPBwkzUAATbvP9b/BDMfy8rLMbgxSUML7KoBxwqOjI1yr07TdK4OGZwZWwTS3+wDwYRWLTK311VgChygAZjA7Rq7cbpp1An3v7gtgUPWqW2j3YW5XnCgQR4HQ1OzWk529W/3i6AsgLakyjUfAx6uS+z0sYaLAMAXQd2ADRt9PedCvV3wGwO939+7xNBuqX3GiwHkUQFWM5XnUnKu0HM8sXAnHdwZA+grVbdwA8ylOFLChABYlw5FFvBO1gj0Aou0H6wdi8REnCthQIMRTmazg7XCcPQBy229+XhaUhkWS4/MrELKC+JJa13UB3P5xb1Pafl1d5MCyArCC6JSQ28LXdDn6LoD09bzbCJSql6UR37YC3U6t521x3F0AtaNvIlCqX5ZGfNsK4Gu5cGQJDWs4NgCiZ0JLujYRIBYQKohLQgFsSMDVMPeGDYBtt72FBAW+JGSXOFkBwAcI4bA/EHwDoO9rY/0cJ7iIC+JEgSQUwHpB4/ygHWgAJDJfRiD2aREnCiSpAANodkajhDoAqgoS7bfzFMLFiQK2FGAjR7WxMXqdKjjogDCdthKTeESBqAKdTgiCK/jjUG8kOOjsxYdAcaJAUgoAQF5hhV1xndacVL9JiS3x9leArSC2ZHa03y7jNg7s/4iEigL2FOChGGIPAOoKosY2uOJEgTQUYGNHw39lB7vRI1HszyxOFEhDAQaQ0io7fqc3EgpMIw+SxgwrwJ0QRzvr3XpXAJxhIqZYdKp59TrSl2m4Kb6FGUuajR3trLvWtYAzpoEUd4oKcIeXhgQvCYBTfBGStFJzm//EWkDqiiw1qR6W1TC7r11JlIurX/6caPy5iJx+uUkd7SOrFYfgM8MwNBKYi7xLJoulgFTBxXqfuSuNAJi7V1asDM99+8fLpvYtly91VykUq4jDSzPtNpntNme0PLbjH67meFexf2C9Hmx8QMOAwVQcj82MF4XcJQrEVyDEmpmKk9Uw8bWUJ2Mo0ANgjOflEVHAmgLSCbEmpUQURwEBMI5q8ow1BQRAa1JKRHEUyAWAPx7Rj+I1afpGXOEUyAWAn+2cqI9/aBROfCkQLT/Iugiwfp/tNtRH3x+LFcz6y4qRv8wDCOu3a6pgX6xgjBec9UcyDSBbPxZRrCArURw/0wCy9WO595tiBVmLoviZBTBq/VhwsYKsRDH8zAIYtX4st1hBVqIYfiYBHGT9WHKxgqxE/v1MAjjI+rHcYgVZifz7mfo5pACsE/XRDycjlYUVhPvT1QV1dTmT/0cjyyA30LfisiBCFzwz2Ezf0BvD4ZkP/n2k/kbjhH++tiggjqFZFm+ZKoBxwIuKiPaigBhVJT/n+snOL8bkXL68llqubYA3KLMvUnU8iUVM+zsU0fQGlaPw4Yd1U8RULWCS4PELE4vISuTDT7X1DgCxC8OlUvLJ/pqWfOE+yyimagFRPb77h2VTRaLz8PfdU1po0Laqz8WSVm/9dlG9fX1J4VhcthVIFUCWIgkQ8wqe7e/tRtuYtuPnd3he/5dfglpwKgBy5m2AmFfwWINZ96cKIIsfBfFjGohGG26YE/CGqZOfa5kAkOViENFy++A/wUwHX4v6b1Eb793fL0WD5TxnCiTfHY0hCOAa1oF4cdlVb9AUnLj8K3AuAD/baSh8bDvA9zb1ZAe5N67J/O8gbfIWHrsKBnjvfnPQLS+gsOlgBbEoIdoWFOtnU+XpxxXLAkbhA4i2LeEgKyjWb/rQ2MzBxABG4ePMJAFhtC0o1o/VLo4/EYCD4GM5bEMYtYJi/Vjp4vhjAzgKPpbENoRsBcX6scLF8sfqhIwLH0sDCOFsdEzYCvq0lausfGaFi+OPBHBS+FgamxDCCj4bMTPC6YqfLwWGAhgXPpbAFoSwgviIK54CA9uA54WPpbLdJuR4xS+GAn0BtAUfSyQQshLiRxU
4A6Bt+DhBgZCVED+sQA+AScHHCQqErIT4rEAXwKTh4wQFQlZCfChgesH/+G9DvfdDenswA0I4G+OEJiL5k1sFHAPfvw5TL4BYwtQlz2SCzntTgI+VEAhZidn1u23AaUkgEE5L+WykO3UAIYNAmA0YppGLTAAoEE7j1WcjzcwAKBBmA4i0c5EpAAXCtF//9NPLHIAC4fShSDMHmQRQIEwTgemmlVkABcLpgpFW6pkGUCBMC4PppZN5AAXC6cGRRsq5AFAgTAOF6aSRGwAFwukAknSquQJQIEwah/Tjzx2AAmH6kCSZYi4BFAiTRCLduHMLoECYLihJpUYA6uAna+j3O/LoZClX/t4afium4+oEoJ9rAFEQgZDfZz78MIB65a9PtinbFbV0USkn1zWyFfWT/l2N6O94WMl03iLx6QtwR/vIdU2Iy9vLK1h+BcCCvdC8FUcAzNsbK0J+u50QXcfvBX9FZdpaXV1VpdLQ3dqKUHQpQwYUaDZb6vnz58hJVSxgBl7ILGcBAJphmFDXeJb1kLKnrIDj+f4zpOmjayxOFEhBAc8LfiNaKy3DMCnoLUlEFOj2QSjcoZ2Xa7jueWIBoYO45BXg2tbzvaeY+zBtQM/rzs8lnwNJYaYVCPU36k5bd+aClQA401SkWHiubbV2ao7Wbg1pt1pBwzDFfEhSM6oAW0Bfq7oz1wragBw4o5pIsVNUoN0O+htzc7QYYWNjrYa0YRYFwhTfwgwnxVXwxgtrnWEYX6zgDPOQatG5qad99RgJB1NxOjhpNpupZkYSmz0FeBCaKuGnKH0AoO+bE6Zz9mSREqelQKvV6iTlhy2gX0Uo09m5QzxRwLoC7XZnGk47vwLott0qUoIFlI6Idc0lwpACWIoF57ZVFb6pgqknjNmQKuCTahiyiEtCAYYPHZAOc502IKVG8H2NRE9PT5NIW+IUBYithlHBVwFrOAk6IebIqcITAKGCuCQUYAvoec4jjr8L4I2ra1UKNNUw38g3iS8KnFeBRqNhJjuw+uqljTXTAUGcXQBxon3/S/gnJ8fwxIkC1hTgmtVX+n440h4AHTKNRGgdFlCsYFgmOT6PAswTrN/vrq09CsfVAyB6JrRE/0PcIFYwLJMcn0eBw8Pg11iJrU+j8RCUvW57e6/sOf43tFSmsry8pBYXF3tvkDNRYAIF0PY7PDxSsH7Xr13eiD7aYwFxEVbQ1/oujo+PT2RgGkKIi6UAll2BIbho248jPAMgLlA9/QV5pkd8cJD+j1lz5sTPtwJoxnWWXn0RbftxyfoCiItuW79JZpM6JE1qDwYU80PiiwKjFDg5aahG4xRVb90tBTVqv2cGAkhVcU35QZcZZpRXsfaLRMJEgbACQdUbDOVR1XsXC0/D18PHAwHETdfX1x5SI/BDzBFjLw+BMCydHPdTAIyAFbOohdgZVPXys2Qhh7tOr/gr6hVvuq6rLl5cVVqPfGx4pHK1kAoAuv19GKo2TWqox9fXL78yqqBDLSAeRq/Y8fTrFGENESMBQ/eomOX6TCnQAx8NuTjz+vVxBBjblJElrND4ICxhRSzhONLOzj1n4CvpV4e1+8LKjA0gHopCeOHCBeW6I41oOD05LpgCaPMdHBwE1S4s3wTwQYqJAMQDYQgd2tgDG1sKhFBm9hx3ODDWRyBNDB8UmxhAPNSB8HN0TNAhWVpalCk7CDNDDuN8x8fHpj+ADgfafONWu2GZYgHIETx5+vND6hLfwfnCwjxBuCTWkMUpqI/2HhYXnJ52vsJLQy2u57yPzmqcIp8LQCT4ZGfvtlb+A9raqIwqGdZwYWEhTl7kmYwr0GP1aIaDVrfcv7F+5eF5sn1uAJE4quS2qx7QlPMtnAPElZUV2fQcYhTAYT0f5nVDa0SrNL32ZpwqNyqHFQA5UmMNff8ehmoQhl335+fnxSKyQDnzo+ARLDVMrXUWq1gpjVUAOUffPf35fUfpvzCIsIgBjAtiFVmkDPpo3+Fruc3mqVlIgHM
4gsQsVJ7znIdx23qDipsIgJxY1CJyOGDEYPYc7c/lOPBdviR+SgoALnyw2gkzXPj02Zigqn39peOpR7bB42ImCiAnsv3j3iaNGVFnRd/E0A2Hh31YSYwnYlgHx/D5A0jZBdd7s8338T2z4DNA0bJibA4O+zCzBeOt93DOkPEWadHn6bxK931NL6Ha+aZkn1vsBfW+SXvxDoyJOixl6rBskUAYQ3yZxpAqg6AcGIlcsKMAtuXDzmjYnEo7VWyXkZSlG5Th1AEclJHtn/YqtHFShYAsA0pPeWXawn8d91PDt0KecbiOIR8+h0/G8kxY+HoRj+nF1cmg1c+UTQd7PVJ4nYbHzHXaf/6po5x6m7bEJa1q2JnURg/2TNoxAv4PoGedQHqhulIAAAAASUVORK5CYII=" + }, + "copyright": "Copyright 2023 Dify", + "privacy_policy": "https://dify.ai\n", + "position": 4, + "chunk_structure": "hierarchical_model", + "language": "en-US" + }, { "id": "982d1788-837a-40c8-b7de-d37b09a9b2bc", "name": "Convert to Markdown", @@ -81,6 +97,22 @@ "position": 6, "chunk_structure": "qa_model", "language": "en-US" + }, + { + "id": "629cb5b8-490a-48bc-808b-ffc13085cb4f", + "name": "Complex PDF with Images & Tables", + "description": "This Knowledge Pipeline extracts images and tables from complex PDF documents for downstream processing.", + "icon": { + "icon_type": "image", + "icon": "87426868-91d6-4774-a535-5fd4595a77b3", + "icon_background": null, + "icon_url": 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAKAAAACgCAYAAACLz2ctAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAAoKADAAQAAAABAAAAoAAAAACn7BmJAAARwElEQVR4Ae1dvXPcxhVfLMAP0RR1pL7MGVu8G7sXXdszotNYne1x6kgpktZSiiRNIrtMilgqnNZSb4/lzm4i5i8w1TvDE+UZyZIlnihKOvIAbN5v7/aIw93xPvBBHPDezBHYBbC7+O2Pb9++/YAlMiIPHjwoO65btpQqK6VKVKySsqwV9fQpSliy6IcTubhYxrFTrJJqXe+Mz2+I8KgJoeh3IIRBTW1vt+MoXLWWlgRheo/uqlmWVSVMa67jVJeXl6sHTx7dGb1HurK9uVnybHtNKXFBWAKEW1XCKvcrhb+tCdi+LBeX2ud80o3AaHipDUGkFErdJXJu2J63vliptAncnXr8MakQ8PH9+2tU9Av0omtCCZx3iZSSsLCE49j6iHPE+U+fCEnnCEOmTp/uehbXzPWuizmNoFaC4CQdFxCE3V9/bcd4vk8txpLwW/f6FPZ9RT8c/fZ9nSdESmGtK1veOvPGG3SerCRGQGg6V8rLxIwPg6QDUWzb1kTDcXrKaROu16v6T550RMuTJzvCHOhEYBS8PM8TIGmj4QrX9ejndiRG5Kj6lvj8zLlzNzsuxBiInYCaeI7zqeWrK8YuA+lmZqbF9PSUcIh0o2irUQCNEZeJTSoqXg0i4d7evial0ZIgopLWzdNvvvl53MDESsBfNrc+sqX6wth0juOIublZMUXHcSUqoOPmO6nPxYkXiFinn9GMIGLcGjEWApLWK7u2/ZVpauMgniFAnICaNPN8TAIvaMXd3ZcHdqMlbjve1NXFSvSetIxaGU/u3//Uk/aPIB+a1rm5Y+LEwnwkrRe1TPx8vAigBVssLYj51+Z0x5Dq+iNXNn58tLV1OWpOYxMQtt7jra0vqFd1HbYe7DsU8tjsTNQy8fMZRQB2PJQLjiQlS4mvwIEoxR2rCdZNrpTfUnd9FVrv2LHZxIiXRJMSBbCsP5sWXvX6nnj1qq5dPOQQ33D86Y/HaZJH1oAgnyflHZAPfrrSieOJkS/rlV3k8s1SS3eC6h4cABc82bizvfmgPComIxHQkA+9XPjwoI6bBRg1W74/Dwig7sEBuNbIDCPFNDoJhyYgky8PlIn/HUDChQgkHIqAvcg3ijM5/tfmFLOEALgwLgmHIiANqX0bbHaZfFmq/myUJUxCV+5/S4qrNKh0AwnY7GY3OxwLx18baRhtUOZ8PV8IgITHiSOmY0KDE9cGveGhBHy0SY5GJa4gYe5wDIKSrwMB0zHBDCZw5+G9e1cOQ6YvAWH3kX2pnYzw8zVZfVhSfI0RaCIAroAzEJp6cu0w90xfApL6pEkFogSvN49uNIHlv8MjAD8hRsdISq7d+Krfkz0J2Gp6PwKT51pM7pcAxzMC/RDQY8fNpnjtV5op1eu+ngSUUmnjEeTjprcXbBw3DALoO5imWJA516tX3EVAmt1yDS4XEK816DxMXnwPI9ATATTFmJ5H5lx5X8quDkkXAZXvX0ZK8/NzPRPkSEZgVAQwKRlCq34+DWvBDgLC9oP2w/yvKLOYdW78hxFoIQAuQQuSNNcJBZDpIKCx/bjpDSDEp7EgYLQgjWR8GEywTcBHmz/r9bls+wXh4fO4EIAWbDmn1x5v3l8z6bYJKKV3GZFTtEyShRFIAoHp5kxq4Ut/zaTfJqAS8gIiufk10PAxbgRajmloQs01pK+n5KNn4kp7GxEnlwZOYMBtqUl4inlqGeckoywt5MfODbXajp7G7/jeIrYB0RoQe7UAb+755oR1GX0NOKYlzZ6GGM5pAhIzVxFp074sLIxAkghg7x8I7VezhmPTBrSs8wiwBgQKL
EkigLVEEIyM4Njs8iqLAtQNsdt9ElzLhGTJhskEIBNeCGxG9YLegaZpaaXXYlyzCcbqJhZGIEkEYAdCjAaUD2jiKSJ41gtQYEkaAd0RoYkuEOyKK2mMroyA3YrEOQsjkCQCRgs6dbcsaYtc7fizZFM1Jpkxp80IAAHTE7ZsVZbkgikjkptgoMCSBgJGAxL3SmiMmxqwZRymUQDOo9gIGAKCe9L0RgKRxUaH3z5xBExrS5xbaTv+9FSZxLPmDBiBTgSId9YKorLohO4sKofygoBRdp5Si20NmJeX4/fIPgLG40JEPMEEzH595bqEtF7Ool4wLUWa0F7wr+//JlMVdOrOfzrKY8p3/C9/FjMXL3ZcK2rADHrQHtPkiBa+dsOYdrmooCT93s//8U+x9/33SWczcelzE5xilYGEjY2NFHPMflZMwJTraOdvfxfuTz+lnGt2s3O8bb0URPheA+NxsZeU5/N1Qqp2d8Wzq38SJ774l3DefrvzYgZDSazJ0V/r3Hmu3xZTEHgoLuWKNyT0Hj5MOedsZBfo8OqhOCbgEdQLSLhDmrCIJOwg4BFgz1m2EAD5ikpCQwIHX9SGyJjWAydhM5jC5vFoSLhANqH9+uuZf8W4bHppNZd/xN/ryDyE2SugIWERm2MmYEb4aEgI27BIwgTMUG2DhDXqmBSJhEzADBEQRfHISV0kEjIBM0ZAQ0KMmBRBmIAZrWWMGWPsOO/CBMxwDWP2TN5JyATMMAFRNJBw98t/Z7yU4xePCTg+dqk9Wf/6a/Hy1q3U8kszIyZgmmhHyOvlzVu5JCETMAIp0n40jyRkAqbNooj55Y2ETMCIhDiKx0HCV19/cxRZx54nEzB2SNNJ8MWXX+ZikRMTMB2+JJJLHnyE/FmkRKhxkGh4nfDBFT4DAqwBmQdHigAT8Ejh58yZgMyBI0WAbcCY4Td7wcScbN/kJt3GZA3Yt2r5QhoIMAHTQJnz6IsAE7AvNHwhDQSYgGmgzHn0RYAJ2BcavpAGAkzANFDmPPoiwATsCw1fSAOBifcDTrofLI1KznIerAGzXDsFKBsTsACVnOVXZAJmuXYKUDYmYAEqOcuvyATMcu0UoGxMwAJUcpZfkQmY5dopQNkmzg846nw7m77Fge9xzH7wgZhaPT+wSodN35qf1+kibef8eTHz3rsD0+51w7D59Xq2V9yk+UUnjoC9QD8sDhs+4odNfqZWV8U8fTQwjs3AsYsptlDTn96ivVt2iZDT770n5i79Lpb0D3unPF0rVBMMstT+8MdEPpUFQoLkSD8vi8bTIHqhCAhAQRR8KiupHemRPhaN53lLtTiJOfFN8CCbp7FxV9RJM+398EMbN5Bkl3YfxffaBkm/9P2Hv2gSI2337t0uQmNLNeSD7wSPIv3yGyWNSbp34gk4CGx0PPCD3RfcY8/Yb7ALxxH5+lmBn+nY7H3/g04/qFnRJDtvvSWO/faTcbIoxDOFaYLnLl/SnZBgrYI0ccnMxQ9Er68doTnmz7P2R7kwBAQE6KEGpUFNZ5wCLdubhPndYjcqfoUiYPj7vMHmMiqQ5nmQEK6eoKC5hz3I0o1AoQgI53EaArsybFvWY2zu03iHtPIoFAHRIw5KWCMGr0U9n363c2QEznCWbgQKRcB6wBUDKOTZs92IxBRjescmubjtTZPupB9z74YxFQQXDNwiQZm9eDEYjPU8PNznD2kDjjo2POl+w1wTEIa/+9P/tH9Oj9kGKAaCTI85gSCQTN/TsL3JnZDeUE08AUfVGIAB5IC7hOXoESiUDQi4QT4MwYWbyLirIqzxwhox7vwmNb2J14CjAB/ndKxB+aLpD8qwhJ90my74zsOc556Akmy9GXKJYK5euGc6DEDj3hMefkuyxz1uGbPw3MQTMKsao/5N54dkZugfgKUbgcLZgN0QxB+DSQ7hYT5niOUA8Zck+yk6/vZTXUpfedkv7QSUEMQLTvtCkWdoPcqwNmDWX9F/8iSWIvq1Zzod1oCxw
NlMBOTb6THbGlPBWHoj4FhC1JQQJaWUsCwKsYyFwCuy+fARwbD7Ze7Spdxov7GA6fEQuNaSmkOnNQowAQ0kQx4xJb9BEwwwHR/T8sPEQzJoeln7dQPaQUB7cVGQ7hOytCCk5BY5DNc4Iy2GfMf/+pdwchMXlidPxl9m3xfSniLWCTHxbpj40YmWIkY80OzyOpDhcGQCDofTwLtAvGOffKKJx8NuA+Fq38AEbEMx2glIBtfKFG3LgVEW5+239DjzaKkU826/1QlRQtWsx1tbd8gIXFtYmBdTDvOxmJRI960brit2dmiNjCXWudeRLvacWwgBEBBuGKH8tm8mdAsHGYHkEJDkk9FjIgHfTHK5ccqMACHgeb7GgdwwVW6CmRLpI3AwEiIkWIgSeOQcZGEE0kCg3QtW6t6BDRhgZRqF4DyKi0DA3KtJy7eanRAmYHEZkfKb+8YGtKyqVI5VRf6uy/MBU66HwmbXboI9qyZd160CiYBaLCww/OLpIOC3+hvurFOVy5VKFdkikn2B6VRA0XMxBFxeXm66YSyhqgCFxuaKjg2/f8IIuJ4x9dQGstKDv8qyaAM7UW40XDEzM51wEUZLPq41CKPlmp+7E5nPFwEe0wEhp989JKMd0Rb5YxA4YCdCLIxA/AhgIgKEiKc1YHMkxLLWEelxTxgwsCSIgPG20PqjAwLanreOPKEBuSOSIPqcNLn7mhrQcE7bgIuVSo3mBa6TK2bN9T0xJbM7LzBrNk3WOJVlm9k0v9Td3QDngF2zCcaZUv/FYX+/gQMLIxA7Anv1fZ0m+Vo01xA4IKAv1xGxt9e8CecsjECcCLQ1oO/fNOm2CXi68uY6pkhjRKR9o7mLj4xARASg2PRgB82+OlOp6A4IkmwTUKev1Hc4vnpZ10H+wwjEhUDdtKyW+DyYZgcBnaZqrEEDshYMwsTnURAAl9D7JduveubcuZvBtDoI2OyZqBu4gbVgECY+j4LA7u5L/Ti5+G6F0+kgIC6SFrxOY8JVsLZe3wvfz2FGYCQEgrbf2crKZ+GHuwgILSh96ypufPmqzo7pMGIcHhoBLPMAh7SEbD+TSBcBceFU5dxt0yPefdFUn+YBPjICwyIAM05PvbLE7bDtZ9LoSUBcpGG539Ohtt9ocFNs0OLj0AjAfNvb1z7lmutN6Ra118N9CagnqvpKd5mhRnnVXC/4OK4XAsGmV1ni6nJludrrPsT1JSAunq6sXKfJqjfgnMZeHkxCoMJyGALgCLgCzlCv90a/ptekcSgBcZPt+59h8Bht+fPnL7hTYpDjYxcCIB040hzxUBtnKitXum4KRQwkIHrFru9/DNeMR9O1nj0ndvM+MiEYOQjyPUMriSl95HD2/OmPh0FlIAGRCOxBUq3vMwmHgbR493STb+r9w+y+IEJDERAP9CIh24RBKIt5Dg50ar7hyQfEhiYgbg6TkDsmQKW4YjocB83uaOQDciMREA8YEpqOybNnz9lPCGAKJvDzoe5Nh8PzRycfIBuZgHgIJDy9svKOcdG8ePlKYMCZm2Sgk28xPV3UOc7hanlB/YNhbb4wOmMR0CRyamXlivKFHjGB1xtNMs+oNujk7witt13bERgdI6kJX12Fq6XSWt8xzhtHIiAyPFM5d5MWMr1DY8e3oY4xdoxC8nzCcaojm8+gLqFcjNbDPAHXn3oHAxVRS2xFTSD4/KPNrctCqmuWsMqIx6772Gkhym4L4VVevCoOyPaXOPEC8TChwCgT+Peoxbt6FpNVYpJYCWjK9Hjz3mdKikuGiPgEmCbj7PTIn4KIE1BTvjwfo+AFmw5rw7EyEqYUwi1Bc3tjV/jXozS3JrHgMRECmgzCGtHEg4y2Y2sySlsKx7bNpa5jFEC7EitAxLB46Q4EEWyf9gOCGwW7YuiNCQ5Ip7/jQSz8bpeWasRNPFMViRLQZPJo8+dV2vjjsiXFBXorOu8WaEmbfvhkLEipj3SOD2oj3oh96hRtbN1ZbNyLX5HEECj8z
o3Hj3UUrmMjSLl0sukqoXPEYWsMfY3s9Z5C9p3wsEZcruuVkj1vii8y9Vrb3NwsHRf2mpJqlVhzntAo9yMlXtN80d28slxcMqd87IHAKHhhWz7sjKY8bBZurT8X3npSmq5HUXVU6gTsV5AHmw/KjnDLBEqJyFmm+0oEzop6+pQ6XQJhLdbiYonCJRPGkT43i3BHXPB6Ts9rhFUt/G7+9nYVcWS94VrNWloSrd3PatgPnLCqusKpjuu3Q9pxyv8BVb3XBNS3Vn0AAAAASUVORK5CYII=" + }, + "copyright": "Copyright 2023 Dify", + "privacy_policy": "https://dify.ai", + "position": 7, + "chunk_structure": "hierarchical_model", + "language": "en-US" } ] }, @@ -5153,7 +5185,7 @@ "language": "zh-Hans", "position": 5 }, - { + "103825d3-7018-43ae-bcf0-f3c001f3eb69": { "chunk_structure": "hierarchical_model", "description": "This knowledge pipeline uses LLMs to extract content from images and tables in documents and automatically generate descriptive annotations for contextual enrichment.", "export_data": "dependencies:\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/jina:0.0.8@d3a6766fbb80890d73fea7ea04803f3e1702c6e6bd621aafb492b86222a193dd\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/parentchild_chunker:0.0.7@ee9c253e7942436b4de0318200af97d98d094262f3c1a56edbe29dcb01fbc158\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/mineru:0.5.0@ca04f2dceb4107e3adf24839756954b7c5bcb7045d035dbab5821595541c093d\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/anthropic:0.2.0@a776815b091c81662b2b54295ef4b8a54b5533c2ec1c66c7c8f2feea724f3248\nkind: rag_pipeline\nrag_pipeline:\n description: ''\n icon: e642577f-da15-4c03-81b9-c9dec9189a3c\n icon_background: null\n icon_type: image\n icon_url: 
data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAAKAAAACgCAYAAACLz2ctAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAAoKADAAQAAAABAAAAoAAAAACn7BmJAAAP9UlEQVR4Ae2dTXPbxhnHdwFRr5ZN2b1kJraouk57i\/IJrJx6jDPT9Fpnkrvj3DOOv0DsXDvJxLk2nUnSW09hPkGc6aWdOBEtpZNLE9Gy3iiSQJ\/\/gg8DQnyFFiAAPjtDLbAA9uWPn5595VKrjLjtn\/YqrZaq+L6quL5X9pQqO1qtI3u+0mXy8MFJxfihP1qrss\/XQ+FFPtRK1UmreriMJkz\/GqaVX8N1z1dPHdyvnZpP1+fmVG3jhTVzDden6SjP6brt7b1y21VbWnk3CawKAbWp9Fmo0s3VbKamffWYgKz5vv+t1s5jt62qGxtrPVAnrUwqAH63u7dF\/4E3qaBbVCB8zjjHcZRDJs91XaXJpOGDMDgSx5zj2HWDMByz4\/v5fBZ80lLhE3Y498jcsfO8Nt1DlYbvmXs9L\/DbbY\/uozqmjwOUSvvVtuN8+tKLa4\/73GI1KDEAYek8x7vta\/0a5XiLcw1Y5uZcAxpgK5VKXeD4HvHTUaDdbivA2Go1yW+rZrPVkzDBUSOk7\/\/u2m8e9VyweGIdQAPenLpD\/3LvcLsM0C0szBNs8wY+nIvLpgKA8PS0YWBkKwkQyUo8un517b7tXFsl4cnO\/25p33lA7YoKMloqzanFxSXj2864xJe8Ao3GaRdGpAYQbVtEKwCS1au0Xf8TyuMWMirgQYXiOFjFw8PDcLvxC7ek79roSZ8bwO3dvTue77+P6hZV69LSElm9heKoLyXpKgCLeHx8zCBSb9m7e972YWwATVvPVfeoL\/YOcjg\/X1IrKyvd3mo313JQKAXQLgSEgBGO3v\/DG9eu3I1byFgAosr1HP9zauttitWLK32+nzs5aRgQMfSDoRtnXr8ep0qeGMAOfF+ho4FxuosXV7vjdfmWVHI\/qQKwhvv7z02VTCDVnJJ+dVIIJwIwDB\/G8FZXLwh8k761gt0PCJ8\/PzDjiHEgHBvAKHywfDKeVzCaYhYH1TAsIQazJ4VwLAAFvphvZoYeiwvh2YnVPqJ1OhwVVLti+foIJEGmNgQbYISG5Creqf85Ga7yKGlGAvj9zh5mNjbR4UCbT6rdUZLO7nWwwf0CMNNyvXuj1BhaBdPU2m2lnE8Q8aVLF6XDMUpNuW4UQMfk2bN9swKHqua7N9avPBwkzUAATbvP9b\/BDMfy8rLMbgxSUML7KoBxwqOjI1yr07TdK4OGZwZWwTS3+wDwYRWLTK311VgChygAZjA7Rq7cbpp1An3v7gtgUPWqW2j3YW5XnCgQR4HQ1OzWk529W\/3i6AsgLakyjUfAx6uS+z0sYaLAMAXQd2ADRt9PedCvV3wGwO939+7xNBuqX3GiwHkUQFWM5XnUnKu0HM8sXAnHdwZA+grVbdwA8ylOFLChABYlw5FFvBO1gj0Aou0H6wdi8REnCthQIMRTmazg7XCcPQBy229+XhaUhkWS4\/MrELKC+JJa13UB3P5xb1Pafl1d5MCyArCC6JSQ28LXdDn6LoD09bzbCJSql6UR37YC3U6t521x3F0AtaNvIlCqX5ZGfNsK4Gu5cGQJDWs4NgCiZ0JLujYRIBYQKohLQgFsSMDVMPeGDYBtt72FBAW+JGSXOFkBwAcI4bA\/EHwDoO9rY\/0cJ7iIC+JEgSQUwHpB4\/ygHWgAJDJfRiD2aREnCiSpAANodkajhDoAqgoS7bfzFMLFiQK2FGAjR7WxMXqdKjjogDCdthKTeESBqAKdTgiCK\/jjUG8kOOjsxYdAcaJAUgoAQF5hhV1xndacVL9JiS3x9leArSC2ZHa0
3y7jNg7s\/4iEigL2FOChGGIPAOoKosY2uOJEgTQUYGNHw39lB7vRI1HszyxOFEhDAQaQ0io7fqc3EgpMIw+SxgwrwJ0QRzvr3XpXAJxhIqZYdKp59TrSl2m4Kb6FGUuajR3trLvWtYAzpoEUd4oKcIeXhgQvCYBTfBGStFJzm\/\/EWkDqiiw1qR6W1TC7r11JlIurX\/6caPy5iJx+uUkd7SOrFYfgM8MwNBKYi7xLJoulgFTBxXqfuSuNAJi7V1asDM99+8fLpvYtly91VykUq4jDSzPtNpntNme0PLbjH67meFexf2C9Hmx8QMOAwVQcj82MF4XcJQrEVyDEmpmKk9Uw8bWUJ2Mo0ANgjOflEVHAmgLSCbEmpUQURwEBMI5q8ow1BQRAa1JKRHEUyAWAPx7Rj+I1afpGXOEUyAWAn+2cqI9\/aBROfCkQLT\/Iugiwfp\/tNtRH3x+LFcz6y4qRv8wDCOu3a6pgX6xgjBec9UcyDSBbPxZRrCArURw\/0wCy9WO595tiBVmLoviZBTBq\/VhwsYKsRDH8zAIYtX4st1hBVqIYfiYBHGT9WHKxgqxE\/v1MAjjI+rHcYgVZifz7mfo5pACsE\/XRDycjlYUVhPvT1QV1dTmT\/0cjyyA30LfisiBCFzwz2Ezf0BvD4ZkP\/n2k\/kbjhH++tiggjqFZFm+ZKoBxwIuKiPaigBhVJT\/n+snOL8bkXL68llqubYA3KLMvUnU8iUVM+zsU0fQGlaPw4Yd1U8RULWCS4PELE4vISuTDT7X1DgCxC8OlUvLJ\/pqWfOE+yyimagFRPb77h2VTRaLz8PfdU1po0Laqz8WSVm\/9dlG9fX1J4VhcthVIFUCWIgkQ8wqe7e\/tRtuYtuPnd3he\/5dfglpwKgBy5m2AmFfwWINZ96cKIIsfBfFjGohGG26YE\/CGqZOfa5kAkOViENFy++A\/wUwHX4v6b1Eb793fL0WD5TxnCiTfHY0hCOAa1oF4cdlVb9AUnLj8K3AuAD\/baSh8bDvA9zb1ZAe5N67J\/O8gbfIWHrsKBnjvfnPQLS+gsOlgBbEoIdoWFOtnU+XpxxXLAkbhA4i2LeEgKyjWb\/rQ2MzBxABG4ePMJAFhtC0o1o\/VLo4\/EYCD4GM5bEMYtYJi\/Vjp4vhjAzgKPpbENoRsBcX6scLF8sfqhIwLH0sDCOFsdEzYCvq0lausfGaFi+OPBHBS+FgamxDCCj4bMTPC6YqfLwWGAhgXPpbAFoSwgviIK54CA9uA54WPpbLdJuR4xS+GAn0BtAUfSyQQshLiRxU4A6Bt+DhBgZCVED+sQA+AScHHCQqErIT4rEAXwKTh4wQFQlZCfChgesH\/+G9DvfdDenswA0I4G+OEJiL5k1sFHAPfvw5TL4BYwtQlz2SCzntTgI+VEAhZidn1u23AaUkgEE5L+WykO3UAIYNAmA0YppGLTAAoEE7j1WcjzcwAKBBmA4i0c5EpAAXCtF\/\/9NPLHIAC4fShSDMHmQRQIEwTgemmlVkABcLpgpFW6pkGUCBMC4PppZN5AAXC6cGRRsq5AFAgTAOF6aSRGwAFwukAknSquQJQIEwah\/Tjzx2AAmH6kCSZYi4BFAiTRCLduHMLoECYLihJpUYA6uAna+j3O\/LoZClX\/t4afium4+oEoJ9rAFEQgZDfZz78MIB65a9PtinbFbV0USkn1zWyFfWT\/l2N6O94WMl03iLx6QtwR\/vIdU2Iy9vLK1h+BcCCvdC8FUcAzNsbK0J+u50QXcfvBX9FZdpaXV1VpdLQ3dqKUHQpQwYUaDZb6vnz58hJVSxgBl7ILGcBAJphmFDXeJb1kLKnrIDj+f4zpOmjayxOFEhBAc8LfiNaKy3DMCnoLUlEFOj2QSjcoZ2Xa7jueWIBoYO45BXg2tbzvaeY+zBtQM\/rzs8lnwNJYaYVCPU36k5bd+aClQA401SkWHiubbV2ao7Wbg1pt1pB
wzDFfEhSM6oAW0Bfq7oz1wragBw4o5pIsVNUoN0O+htzc7QYYWNjrYa0YRYFwhTfwgwnxVXwxgtrnWEYX6zgDPOQatG5qad99RgJB1NxOjhpNpupZkYSmz0FeBCaKuGnKH0AoO+bE6Zz9mSREqelQKvV6iTlhy2gX0Uo09m5QzxRwLoC7XZnGk47vwLott0qUoIFlI6Idc0lwpACWIoF57ZVFb6pgqknjNmQKuCTahiyiEtCAYYPHZAOc502IKVG8H2NRE9PT5NIW+IUBYithlHBVwFrOAk6IebIqcITAKGCuCQUYAvoec4jjr8L4I2ra1UKNNUw38g3iS8KnFeBRqNhJjuw+uqljTXTAUGcXQBxon3\/S\/gnJ8fwxIkC1hTgmtVX+n440h4AHTKNRGgdFlCsYFgmOT6PAswTrN\/vrq09CsfVAyB6JrRE\/0PcIFYwLJMcn0eBw8Pg11iJrU+j8RCUvW57e6\/sOf43tFSmsry8pBYXF3tvkDNRYAIF0PY7PDxSsH7Xr13eiD7aYwFxEVbQ1\/oujo+PT2RgGkKIi6UAll2BIbho248jPAMgLlA9\/QV5pkd8cJD+j1lz5sTPtwJoxnWWXn0RbftxyfoCiItuW79JZpM6JE1qDwYU80PiiwKjFDg5aahG4xRVb90tBTVqv2cGAkhVcU35QZcZZpRXsfaLRMJEgbACQdUbDOVR1XsXC0\/D18PHAwHETdfX1x5SI\/BDzBFjLw+BMCydHPdTAIyAFbOohdgZVPXys2Qhh7tOr\/gr6hVvuq6rLl5cVVqPfGx4pHK1kAoAuv19GKo2TWqox9fXL78yqqBDLSAeRq\/Y8fTrFGENESMBQ\/eomOX6TCnQAx8NuTjz+vVxBBjblJElrND4ICxhRSzhONLOzj1n4CvpV4e1+8LKjA0gHopCeOHCBeW6I41oOD05LpgCaPMdHBwE1S4s3wTwQYqJAMQDYQgd2tgDG1sKhFBm9hx3ODDWRyBNDB8UmxhAPNSB8HN0TNAhWVpalCk7CDNDDuN8x8fHpj+ADgfafONWu2GZYgHIETx5+vND6hLfwfnCwjxBuCTWkMUpqI\/2HhYXnJ52vsJLQy2u57yPzmqcIp8LQCT4ZGfvtlb+A9raqIwqGdZwYWEhTl7kmYwr0GP1aIaDVrfcv7F+5eF5sn1uAJE4quS2qx7QlPMtnAPElZUV2fQcYhTAYT0f5nVDa0SrNL32ZpwqNyqHFQA5UmMNff8ehmoQhl335+fnxSKyQDnzo+ARLDVMrXUWq1gpjVUAOUffPf35fUfpvzCIsIgBjAtiFVmkDPpo3+Fruc3mqVlIgHM4gsQsVJ7znIdx23qDipsIgJxY1CJyOGDEYPYc7c\/lOPBdviR+SgoALnyw2gkzXPj02Zigqn39peOpR7bB42ImCiAnsv3j3iaNGVFnRd\/E0A2Hh31YSYwnYlgHx\/D5A0jZBdd7s8338T2z4DNA0bJibA4O+zCzBeOt93DOkPEWadHn6bxK931NL6Ha+aZkn1vsBfW+SXvxDoyJOixl6rBskUAYQ3yZxpAqg6AcGIlcsKMAtuXDzmjYnEo7VWyXkZSlG5Th1AEclJHtn\/YqtHFShYAsA0pPeWXawn8d91PDt0KecbiOIR8+h0\/G8kxY+HoRj+nF1cmg1c+UTQd7PVJ4nYbHzHXaf\/6po5x6m7bEJa1q2JnURg\/2TNoxAv4PoGedQHqhulIAAAAASUVORK5CYII=\n name: Contextual Enrichment Using LLM\nversion: 0.1.0\nworkflow:\n conversation_variables: []\n environment_variables: []\n features: {}\n graph:\n edges:\n - data:\n isInLoop: false\n sourceType: tool\n targetType: knowledge-index\n id: 
1751336942081-source-1750400198569-target\n selected: false\n source: '1751336942081'\n sourceHandle: source\n target: '1750400198569'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInLoop: false\n sourceType: llm\n targetType: tool\n id: 1758002850987-source-1751336942081-target\n source: '1758002850987'\n sourceHandle: source\n target: '1751336942081'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n isInLoop: false\n sourceType: datasource\n targetType: tool\n id: 1756915693835-source-1758027159239-target\n source: '1756915693835'\n sourceHandle: source\n target: '1758027159239'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInLoop: false\n sourceType: tool\n targetType: llm\n id: 1758027159239-source-1758002850987-target\n source: '1758027159239'\n sourceHandle: source\n target: '1758002850987'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n chunk_structure: hierarchical_model\n embedding_model: jina-embeddings-v2-base-en\n embedding_model_provider: langgenius\/jina\/jina\n index_chunk_variable_selector:\n - '1751336942081'\n - result\n indexing_technique: high_quality\n keyword_number: 10\n retrieval_model:\n reranking_enable: true\n reranking_mode: reranking_model\n reranking_model:\n reranking_model_name: jina-reranker-v1-base-en\n reranking_provider_name: langgenius\/jina\/jina\n score_threshold: 0\n score_threshold_enabled: false\n search_method: hybrid_search\n top_k: 3\n weights: null\n selected: false\n title: Knowledge Base\n type: knowledge-index\n height: 114\n id: '1750400198569'\n position:\n x: 474.7618603027596\n y: 282\n positionAbsolute:\n x: 474.7618603027596\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n author: TenTen\n desc: ''\n height: 458\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Currently\n we support 5 types of \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Data\n Sources\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\":\n File Upload, Text Input, Online Drive, Online Doc, and Web Crawler. Different\n types of Data Sources have different input and output types. The output\n of File Upload and Online Drive are files, while the output of Online Doc\n and WebCrawler are pages. You can find more Data Sources on our Marketplace.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"A\n Knowledge Pipeline can have multiple data sources. Each data source can\n be selected more than once with different settings. Each added data source\n is a tab on the add file interface. 
However, each time the user can only\n select one data source to import the file and trigger its subsequent processing.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 458\n id: '1751264451381'\n position:\n x: -893.2836123260277\n y: 378.2537898330178\n positionAbsolute:\n x: -893.2836123260277\n y: 378.2537898330178\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n author: TenTen\n desc: ''\n height: 260\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"A\n \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Knowledge\n Pipeline\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n starts with Data Source as the starting node and ends with the knowledge\n base node. 
The general steps are: import documents from the data source\n \u2192 use extractor to extract document content \u2192 split and clean content into\n structured chunks \u2192 store in the knowledge base.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n user input variables required by the Knowledge Pipeline node must be predefined\n and managed via the Input Field section located in the top-right corner\n of the orchestration canvas. It determines what input fields the end users\n will see and need to fill in when importing files to the knowledge base\n through this pipeline.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Unique\n Inputs: Input fields defined here are only available to the selected data\n source and its downstream nodes.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Global\n Inputs: These input fields are shared across all subsequent nodes after\n the data source and are typically set during the Process Documents 
step.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"For\n more information, see \",\"type\":\"text\",\"version\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https:\/\/docs.dify.ai\/en\/guides\/knowledge-base\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"link\",\"version\":1,\"rel\":\"noreferrer\",\"target\":null,\"title\":null,\"url\":\"https:\/\/docs.dify.ai\/en\/guides\/knowledge-base\"},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\".\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 1182\n height: 260\n id: '1751266376760'\n position:\n x: -704.0614991386192\n y: -73.30453110517956\n positionAbsolute:\n x: -704.0614991386192\n y: -73.30453110517956\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 1182\n - data:\n author: TenTen\n desc: ''\n height: 304\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"MinerU\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n is an advanced open-source document extractor designed 
specifically to convert\n complex, unstructured documents\u2014such as PDFs, Word files, and PPTs\u2014into\n high-quality, machine-readable formats like Markdown and JSON. MinerU addresses\n challenges in document parsing such as layout detection, formula recognition,\n and multi-language support, which are critical for generating high-quality\n training corpora for LLMs.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 304\n id: '1751266402561'\n position:\n x: -555.2228329530462\n y: 592.0458661166498\n positionAbsolute:\n x: -555.2228329530462\n y: 592.0458661166498\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n author: TenTen\n desc: ''\n height: 554\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Parent-Child\n Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n addresses the dilemma of context and precision by leveraging a two-tier\n hierarchical approach that effectively balances the trade-off between accurate\n matching and comprehensive contextual information in RAG systems. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Here\n is the essential mechanism of this structured, two-level information access:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"-\n Query Matching with Child Chunks: Small, focused pieces of information,\n often as concise as a single sentence within a paragraph, are used to match\n the user''s query. These child chunks enable precise and relevant initial\n retrieval.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"-\n Contextual Enrichment with Parent Chunks: Larger, encompassing sections\u2014such\n as a paragraph, a section, or even an entire document\u2014that include the matched\n child chunks are then retrieved. 
These parent chunks provide comprehensive\n context for the Language Model (LLM).\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 554\n id: '1751266447821'\n position:\n x: 153.2996965006646\n y: 378.2537898330178\n positionAbsolute:\n x: 153.2996965006646\n y: 378.2537898330178\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n author: TenTen\n desc: ''\n height: 411\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n knowledge base provides two indexing methods:\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"High-Quality\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0and\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Economical\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\",\n each with different retrieval strategies. High-Quality mode uses embeddings\n for vectorization and supports vector, full-text, and hybrid retrieval,\n offering more accurate results but higher resource usage. 
Economical mode\n uses keyword-based inverted indexing with no token consumption but lower\n accuracy; upgrading to High-Quality is possible, but downgrading requires\n creating a new knowledge base.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"*\n Parent-Child Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0and\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Q&A\n Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0only\n support the\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"High-Quality\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0indexing\n method.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"start\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 411\n id: '1751266580099'\n position:\n x: 482.3389174180554\n y: 437.9839361130071\n positionAbsolute:\n x: 482.3389174180554\n y: 437.9839361130071\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n is_team_authorization: true\n output_schema:\n properties:\n result:\n description: Parent child chunks result\n items:\n type: object\n type: array\n type: 
object\n paramSchemas:\n - auto_generate: null\n default: null\n form: llm\n human_description:\n en_US: ''\n ja_JP: ''\n pt_BR: ''\n zh_Hans: ''\n label:\n en_US: Input Content\n ja_JP: Input Content\n pt_BR: Conte\u00fado de Entrada\n zh_Hans: \u8f93\u5165\u6587\u672c\n llm_description: The text you want to chunk.\n max: null\n min: null\n name: input_text\n options: []\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: paragraph\n form: llm\n human_description:\n en_US: Split text into paragraphs based on separator and maximum chunk\n length, using split text as parent block or entire document as parent\n block and directly retrieve.\n ja_JP: Split text into paragraphs based on separator and maximum chunk\n length, using split text as parent block or entire document as parent\n block and directly retrieve.\n pt_BR: Dividir texto em par\u00e1grafos com base no separador e no comprimento\n m\u00e1ximo do bloco, usando o texto dividido como bloco pai ou documento\n completo como bloco pai e diretamente recuper\u00e1-lo.\n zh_Hans: \u6839\u636e\u5206\u9694\u7b26\u548c\u6700\u5927\u5757\u957f\u5ea6\u5c06\u6587\u672c\u62c6\u5206\u4e3a\u6bb5\u843d\uff0c\u4f7f\u7528\u62c6\u5206\u6587\u672c\u4f5c\u4e3a\u68c0\u7d22\u7684\u7236\u5757\u6216\u6574\u4e2a\u6587\u6863\u7528\u4f5c\u7236\u5757\u5e76\u76f4\u63a5\u68c0\u7d22\u3002\n label:\n en_US: Parent Mode\n ja_JP: Parent Mode\n pt_BR: Modo Pai\n zh_Hans: \u7236\u5757\u6a21\u5f0f\n llm_description: Split text into paragraphs based on separator and maximum\n chunk length, using split text as parent block or entire document as parent\n block and directly retrieve.\n max: null\n min: null\n name: parent_mode\n options:\n - label:\n en_US: Paragraph\n ja_JP: Paragraph\n pt_BR: Par\u00e1grafo\n zh_Hans: \u6bb5\u843d\n value: paragraph\n - label:\n en_US: Full Document\n ja_JP: Full Document\n pt_BR: Documento Completo\n zh_Hans: \u5168\u6587\n 
value: full_doc\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: '\n\n\n '\n form: llm\n human_description:\n en_US: Separator used for chunking\n ja_JP: Separator used for chunking\n pt_BR: Separador usado para divis\u00e3o\n zh_Hans: \u7528\u4e8e\u5206\u5757\u7684\u5206\u9694\u7b26\n label:\n en_US: Parent Delimiter\n ja_JP: Parent Delimiter\n pt_BR: Separador de Pai\n zh_Hans: \u7236\u5757\u5206\u9694\u7b26\n llm_description: The separator used to split chunks\n max: null\n min: null\n name: separator\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 1024\n form: llm\n human_description:\n en_US: Maximum length for chunking\n ja_JP: Maximum length for chunking\n pt_BR: Comprimento m\u00e1ximo para divis\u00e3o\n zh_Hans: \u7528\u4e8e\u5206\u5757\u7684\u6700\u5927\u957f\u5ea6\n label:\n en_US: Maximum Parent Chunk Length\n ja_JP: Maximum Parent Chunk Length\n pt_BR: Comprimento M\u00e1ximo do Bloco Pai\n zh_Hans: \u6700\u5927\u7236\u5757\u957f\u5ea6\n llm_description: Maximum length allowed per chunk\n max: null\n min: null\n name: max_length\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: number\n - auto_generate: null\n default: '. 
'\n form: llm\n human_description:\n en_US: Separator used for subchunking\n ja_JP: Separator used for subchunking\n pt_BR: Separador usado para subdivis\u00e3o\n zh_Hans: \u7528\u4e8e\u5b50\u5206\u5757\u7684\u5206\u9694\u7b26\n label:\n en_US: Child Delimiter\n ja_JP: Child Delimiter\n pt_BR: Separador de Subdivis\u00e3o\n zh_Hans: \u5b50\u5206\u5757\u5206\u9694\u7b26\n llm_description: The separator used to split subchunks\n max: null\n min: null\n name: subchunk_separator\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 512\n form: llm\n human_description:\n en_US: Maximum length for subchunking\n ja_JP: Maximum length for subchunking\n pt_BR: Comprimento m\u00e1ximo para subdivis\u00e3o\n zh_Hans: \u7528\u4e8e\u5b50\u5206\u5757\u7684\u6700\u5927\u957f\u5ea6\n label:\n en_US: Maximum Child Chunk Length\n ja_JP: Maximum Child Chunk Length\n pt_BR: Comprimento M\u00e1ximo de Subdivis\u00e3o\n zh_Hans: \u5b50\u5206\u5757\u6700\u5927\u957f\u5ea6\n llm_description: Maximum length allowed per subchunk\n max: null\n min: null\n name: subchunk_max_length\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: number\n - auto_generate: null\n default: 0\n form: llm\n human_description:\n en_US: Whether to remove consecutive spaces, newlines and tabs\n ja_JP: Whether to remove consecutive spaces, newlines and tabs\n pt_BR: Se deve remover espa\u00e7os extras no texto\n zh_Hans: \u662f\u5426\u79fb\u9664\u6587\u672c\u4e2d\u7684\u8fde\u7eed\u7a7a\u683c\u3001\u6362\u884c\u7b26\u548c\u5236\u8868\u7b26\n label:\n en_US: Replace consecutive spaces, newlines and tabs\n ja_JP: Replace consecutive spaces, newlines and tabs\n pt_BR: Substituir espa\u00e7os consecutivos, novas linhas e guias\n zh_Hans: \u66ff\u6362\u8fde\u7eed\u7a7a\u683c\u3001\u6362\u884c\u7b26\u548c\u5236\u8868\u7b26\n llm_description: Whether to remove 
consecutive spaces, newlines and tabs\n max: null\n min: null\n name: remove_extra_spaces\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: 0\n form: llm\n human_description:\n en_US: Whether to remove URLs and emails in the text\n ja_JP: Whether to remove URLs and emails in the text\n pt_BR: Se deve remover URLs e e-mails no texto\n zh_Hans: \u662f\u5426\u79fb\u9664\u6587\u672c\u4e2d\u7684URL\u548c\u7535\u5b50\u90ae\u4ef6\u5730\u5740\n label:\n en_US: Delete all URLs and email addresses\n ja_JP: Delete all URLs and email addresses\n pt_BR: Remover todas as URLs e e-mails\n zh_Hans: \u5220\u9664\u6240\u6709URL\u548c\u7535\u5b50\u90ae\u4ef6\u5730\u5740\n llm_description: Whether to remove URLs and emails in the text\n max: null\n min: null\n name: remove_urls_emails\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n params:\n input_text: ''\n max_length: ''\n parent_mode: ''\n remove_extra_spaces: ''\n remove_urls_emails: ''\n separator: ''\n subchunk_max_length: ''\n subchunk_separator: ''\n provider_id: langgenius\/parentchild_chunker\/parentchild_chunker\n provider_name: langgenius\/parentchild_chunker\/parentchild_chunker\n provider_type: builtin\n selected: false\n title: Parent-child Chunker\n tool_configurations: {}\n tool_description: Process documents into parent-child chunk structures\n tool_label: Parent-child Chunker\n tool_name: parentchild_chunker\n tool_node_version: '2'\n tool_parameters:\n input_text:\n type: mixed\n value: '{{#1758002850987.text#}}'\n max_length:\n type: variable\n value:\n - rag\n - shared\n - Maximum_Parent_Length\n parent_mode:\n type: variable\n value:\n - rag\n - shared\n - Parent_Mode\n remove_extra_spaces:\n type: variable\n value:\n - rag\n - shared\n - clean_1\n remove_urls_emails:\n type: variable\n value:\n - rag\n - shared\n - clean_2\n separator:\n 
type: mixed\n value: '{{#rag.shared.Parent_Delimiter#}}'\n subchunk_max_length:\n type: variable\n value:\n - rag\n - shared\n - Maximum_Child_Length\n subchunk_separator:\n type: mixed\n value: '{{#rag.shared.Child_Delimiter#}}'\n type: tool\n height: 52\n id: '1751336942081'\n position:\n x: 144.55897745117755\n y: 282\n positionAbsolute:\n x: 144.55897745117755\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n author: TenTen\n desc: ''\n height: 446\n selected: true\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"In\n this step, the LLM is responsible for enriching and reorganizing content,\n along with images and tables. The goal is to maintain the integrity of image\n URLs and tables while providing contextual descriptions and summaries to\n enhance understanding. The content should be structured into well-organized\n paragraphs, using double newlines to separate them. The LLM should enrich\n the document by adding relevant descriptions for images and extracting key\n insights from tables, ensuring the content remains easy to retrieve within\n a Retrieval-Augmented Generation (RAG) system. 
The final output should preserve\n the original structure, making it more accessible for knowledge retrieval.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 446\n id: '1753967810859'\n position:\n x: -176.67459682201036\n y: 405.2790698865377\n positionAbsolute:\n x: -176.67459682201036\n y: 405.2790698865377\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n datasource_configurations: {}\n datasource_label: File\n datasource_name: upload-file\n datasource_parameters: {}\n fileExtensions:\n - pdf\n - doc\n - docx\n - pptx\n - ppt\n - jpg\n - png\n - jpeg\n plugin_id: langgenius\/file\n provider_name: file\n provider_type: local_file\n selected: false\n title: File\n type: datasource\n height: 52\n id: '1756915693835'\n position:\n x: -893.2836123260277\n y: 282\n positionAbsolute:\n x: -893.2836123260277\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n context:\n enabled: false\n variable_selector: []\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: claude-3-5-sonnet-20240620\n provider: langgenius\/anthropic\/anthropic\n prompt_template:\n - id: beb97761-d30d-4549-9b67-de1b8292e43d\n role: system\n text: \"You are an AI document assistant. \\nYour tasks are:\\nEnrich the content\\\n \\ contextually:\\nAdd meaningful descriptions for each image.\\nSummarize\\\n \\ key information from each table.\\nOutput the enriched content\u00a0with clear\\\n \\ annotations showing the\u00a0corresponding image and table positions, so\\\n \\ the text can later be aligned back into the original document. 
Preserve\\\n \\ any ![image] URLs from the input text.\\nYou will receive two inputs:\\n\\\n The file and text\u00a0(may contain images url and tables).\\nThe final output\\\n \\ should be a\u00a0single, enriched version of the original document with ![image]\\\n \\ url preserved.\\nGenerate output directly without saying words like:\\\n \\ Here's the enriched version of the original text with the image description\\\n \\ inserted.\"\n - id: f92ef0cd-03a7-48a7-80e8-bcdc965fb399\n role: user\n text: The file is {{#1756915693835.file#}} and the text are\u00a0{{#1758027159239.text#}}.\n selected: false\n title: LLM\n type: llm\n vision:\n configs:\n detail: high\n variable_selector:\n - '1756915693835'\n - file\n enabled: true\n height: 88\n id: '1758002850987'\n position:\n x: -176.67459682201036\n y: 282\n positionAbsolute:\n x: -176.67459682201036\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n is_team_authorization: true\n paramSchemas:\n - auto_generate: null\n default: null\n form: llm\n human_description:\n en_US: The file to be parsed(support pdf, ppt, pptx, doc, docx, png, jpg,\n jpeg)\n ja_JP: \u89e3\u6790\u3059\u308b\u30d5\u30a1\u30a4\u30eb(pdf\u3001ppt\u3001pptx\u3001doc\u3001docx\u3001png\u3001jpg\u3001jpeg\u3092\u30b5\u30dd\u30fc\u30c8)\n pt_BR: The file to be parsed(support pdf, ppt, pptx, doc, docx, png, jpg,\n jpeg)\n zh_Hans: \u7528\u4e8e\u89e3\u6790\u7684\u6587\u4ef6(\u652f\u6301 pdf, ppt, pptx, doc, docx, png, jpg, jpeg)\n label:\n en_US: file\n ja_JP: file\n pt_BR: file\n zh_Hans: file\n llm_description: The file to be parsed (support pdf, ppt, pptx, doc, docx,\n png, jpg, jpeg)\n max: null\n min: null\n name: file\n options: []\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: file\n - auto_generate: null\n default: auto\n form: form\n human_description:\n en_US: (For local deployment v1 and v2) Parsing method, can be auto, ocr,\n or txt. 
Default is auto. If results are not satisfactory, try ocr\n ja_JP: \uff08\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v1\u3068v2\u7528\uff09\u89e3\u6790\u65b9\u6cd5\u306f\u3001auto\u3001ocr\u3001\u307e\u305f\u306ftxt\u306e\u3044\u305a\u308c\u304b\u3067\u3059\u3002\u30c7\u30d5\u30a9\u30eb\u30c8\u306fauto\u3067\u3059\u3002\u7d50\u679c\u304c\u6e80\u8db3\u3067\u304d\u306a\u3044\u5834\u5408\u306f\u3001ocr\u3092\u8a66\u3057\u3066\u304f\u3060\u3055\u3044\n pt_BR: (For local deployment v1 and v2) Parsing method, can be auto, ocr,\n or txt. Default is auto. If results are not satisfactory, try ocr\n zh_Hans: \uff08\u7528\u4e8e\u672c\u5730\u90e8\u7f72v1\u548cv2\u7248\u672c\uff09\u89e3\u6790\u65b9\u6cd5\uff0c\u53ef\u4ee5\u662fauto, ocr, \u6216 txt\u3002\u9ed8\u8ba4\u662fauto\u3002\u5982\u679c\u7ed3\u679c\u4e0d\u7406\u60f3\uff0c\u8bf7\u5c1d\u8bd5ocr\n label:\n en_US: parse method\n ja_JP: \u89e3\u6790\u65b9\u6cd5\n pt_BR: parse method\n zh_Hans: \u89e3\u6790\u65b9\u6cd5\n llm_description: (For local deployment v1 and v2) Parsing method, can be\n auto, ocr, or txt. Default is auto. 
If results are not satisfactory, try\n ocr\n max: null\n min: null\n name: parse_method\n options:\n - icon: ''\n label:\n en_US: auto\n ja_JP: auto\n pt_BR: auto\n zh_Hans: auto\n value: auto\n - icon: ''\n label:\n en_US: ocr\n ja_JP: ocr\n pt_BR: ocr\n zh_Hans: ocr\n value: ocr\n - icon: ''\n label:\n en_US: txt\n ja_JP: txt\n pt_BR: txt\n zh_Hans: txt\n value: txt\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: 1\n form: form\n human_description:\n en_US: (For official API and local deployment v2) Whether to enable formula\n recognition\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\u3068\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v2\u7528\uff09\u6570\u5f0f\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API and local deployment v2) Whether to enable formula\n recognition\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\u548c\u672c\u5730\u90e8\u7f72v2\u7248\u672c\uff09\u662f\u5426\u5f00\u542f\u516c\u5f0f\u8bc6\u522b\n label:\n en_US: Enable formula recognition\n ja_JP: \u6570\u5f0f\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable formula recognition\n zh_Hans: \u5f00\u542f\u516c\u5f0f\u8bc6\u522b\n llm_description: (For official API and local deployment v2) Whether to enable\n formula recognition\n max: null\n min: null\n name: enable_formula\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: 1\n form: form\n human_description:\n en_US: (For official API and local deployment v2) Whether to enable table\n recognition\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\u3068\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v2\u7528\uff09\u8868\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API and local deployment v2) Whether to enable table\n recognition\n 
zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\u548c\u672c\u5730\u90e8\u7f72v2\u7248\u672c\uff09\u662f\u5426\u5f00\u542f\u8868\u683c\u8bc6\u522b\n label:\n en_US: Enable table recognition\n ja_JP: \u8868\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable table recognition\n zh_Hans: \u5f00\u542f\u8868\u683c\u8bc6\u522b\n llm_description: (For official API and local deployment v2) Whether to enable\n table recognition\n max: null\n min: null\n name: enable_table\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: auto\n form: form\n human_description:\n en_US: '(For official API and local deployment v2) Specify document language,\n default ch, can be set to auto(local deployment need to specify the\n language, default ch), other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/version3.x\/pipeline_usage\/OCR.html#5'\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\u3068\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v2\u7528\uff09\u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\u3092\u6307\u5b9a\u3057\u307e\u3059\u3002\u30c7\u30d5\u30a9\u30eb\u30c8\u306fch\u3067\u3001auto\u306b\u8a2d\u5b9a\u3067\u304d\u307e\u3059\u3002auto\u306e\u5834\u5408\uff08\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8\u3067\u306f\u8a00\u8a9e\u3092\u6307\u5b9a\u3059\u308b\u5fc5\u8981\u304c\u3042\u308a\u307e\u3059\u3002\u30c7\u30d5\u30a9\u30eb\u30c8\u306fch\u3067\u3059\uff09\u3001\u30e2\u30c7\u30eb\u306f\u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\u3092\u81ea\u52d5\u7684\u306b\u8b58\u5225\u3057\u307e\u3059\u3002\u4ed6\u306e\u30aa\u30d7\u30b7\u30e7\u30f3\u5024\u30ea\u30b9\u30c8\u306b\u3064\u3044\u3066\u306f\u3001\u6b21\u3092\u53c2\u7167\u3057\u3066\u304f\u3060\u3055\u3044\uff1ahttps:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/version3.x\/pipeline_usage\/OCR.html#5\n pt_BR: '(For official API and local deployment v2) Specify 
document language,\n default ch, can be set to auto(local deployment need to specify the\n language, default ch), other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/version3.x\/pipeline_usage\/OCR.html#5'\n zh_Hans: \uff08\u4ec5\u9650\u5b98\u65b9api\u548c\u672c\u5730\u90e8\u7f72v2\u7248\u672c\uff09\u6307\u5b9a\u6587\u6863\u8bed\u8a00\uff0c\u9ed8\u8ba4 ch\uff0c\u53ef\u4ee5\u8bbe\u7f6e\u4e3aauto\uff0c\u5f53\u4e3aauto\u65f6\u6a21\u578b\u4f1a\u81ea\u52a8\u8bc6\u522b\u6587\u6863\u8bed\u8a00\uff08\u672c\u5730\u90e8\u7f72\u9700\u8981\u6307\u5b9a\u660e\u786e\u7684\u8bed\u8a00\uff0c\u9ed8\u8ba4ch\uff09\uff0c\u5176\u4ed6\u53ef\u9009\u503c\u5217\u8868\u8be6\u89c1\uff1ahttps:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/version3.x\/pipeline_usage\/OCR.html#5\n label:\n en_US: Document language\n ja_JP: \u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\n pt_BR: Document language\n zh_Hans: \u6587\u6863\u8bed\u8a00\n llm_description: '(For official API and local deployment v2) Specify document\n language, default ch, can be set to auto(local deployment need to specify\n the language, default ch), other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/version3.x\/pipeline_usage\/OCR.html#5'\n max: null\n min: null\n name: language\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 0\n form: form\n human_description:\n en_US: (For official API) Whether to enable OCR recognition\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09OCR\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API) Whether to enable OCR recognition\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u662f\u5426\u5f00\u542fOCR\u8bc6\u522b\n label:\n en_US: Enable OCR recognition\n ja_JP: OCR\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable OCR recognition\n zh_Hans: \u5f00\u542fOCR\u8bc6\u522b\n 
llm_description: (For official API) Whether to enable OCR recognition\n max: null\n min: null\n name: enable_ocr\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: '[]'\n form: form\n human_description:\n en_US: '(For official API) Example: [\"docx\",\"html\"], markdown, json are\n the default export formats, no need to set, this parameter only supports\n one or more of docx, html, latex'\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09\u4f8b\uff1a[\"docx\",\"html\"]\u3001markdown\u3001json\u306f\u30c7\u30d5\u30a9\u30eb\u30c8\u306e\u30a8\u30af\u30b9\u30dd\u30fc\u30c8\u5f62\u5f0f\u3067\u3042\u308a\u3001\u8a2d\u5b9a\u3059\u308b\u5fc5\u8981\u306f\u3042\u308a\u307e\u305b\u3093\u3002\u3053\u306e\u30d1\u30e9\u30e1\u30fc\u30bf\u306f\u3001docx\u3001html\u3001latex\u306e3\u3064\u306e\u5f62\u5f0f\u306e\u3044\u305a\u308c\u304b\u307e\u305f\u306f\u8907\u6570\u306e\u307f\u3092\u30b5\u30dd\u30fc\u30c8\u3057\u307e\u3059\n pt_BR: '(For official API) Example: [\"docx\",\"html\"], markdown, json are\n the default export formats, no need to set, this parameter only supports\n one or more of docx, html, latex'\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u793a\u4f8b\uff1a[\"docx\",\"html\"],markdown\u3001json\u4e3a\u9ed8\u8ba4\u5bfc\u51fa\u683c\u5f0f\uff0c\u65e0\u987b\u8bbe\u7f6e\uff0c\u8be5\u53c2\u6570\u4ec5\u652f\u6301docx\u3001html\u3001latex\u4e09\u79cd\u683c\u5f0f\u4e2d\u7684\u4e00\u4e2a\u6216\u591a\u4e2a\n label:\n en_US: Extra export formats\n ja_JP: \u8ffd\u52a0\u306e\u30a8\u30af\u30b9\u30dd\u30fc\u30c8\u5f62\u5f0f\n pt_BR: Extra export formats\n zh_Hans: \u989d\u5916\u5bfc\u51fa\u683c\u5f0f\n llm_description: '(For official API) Example: [\"docx\",\"html\"], markdown,\n json are the default export formats, no need to set, this parameter only\n supports one or more of docx, html, latex'\n max: null\n min: null\n name: extra_formats\n options: []\n placeholder: null\n precision: 
null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: pipeline\n form: form\n human_description:\n en_US: '(For local deployment v2) Example: pipeline, vlm-transformers,\n vlm-sglang-engine, vlm-sglang-client, default is pipeline'\n ja_JP: \uff08\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v2\u7528\uff09\u4f8b\uff1apipeline\u3001vlm-transformers\u3001vlm-sglang-engine\u3001vlm-sglang-client\u3001\u30c7\u30d5\u30a9\u30eb\u30c8\u306fpipeline\n pt_BR: '(For local deployment v2) Example: pipeline, vlm-transformers,\n vlm-sglang-engine, vlm-sglang-client, default is pipeline'\n zh_Hans: \uff08\u7528\u4e8e\u672c\u5730\u90e8\u7f72v2\u7248\u672c\uff09\u793a\u4f8b\uff1apipeline\u3001vlm-transformers\u3001vlm-sglang-engine\u3001vlm-sglang-client\uff0c\u9ed8\u8ba4\u503c\u4e3apipeline\n label:\n en_US: Backend type\n ja_JP: \u30d0\u30c3\u30af\u30a8\u30f3\u30c9\u30bf\u30a4\u30d7\n pt_BR: Backend type\n zh_Hans: \u89e3\u6790\u540e\u7aef\n llm_description: '(For local deployment v2) Example: pipeline, vlm-transformers,\n vlm-sglang-engine, vlm-sglang-client, default is pipeline'\n max: null\n min: null\n name: backend\n options:\n - icon: ''\n label:\n en_US: pipeline\n ja_JP: pipeline\n pt_BR: pipeline\n zh_Hans: pipeline\n value: pipeline\n - icon: ''\n label:\n en_US: vlm-transformers\n ja_JP: vlm-transformers\n pt_BR: vlm-transformers\n zh_Hans: vlm-transformers\n value: vlm-transformers\n - icon: ''\n label:\n en_US: vlm-sglang-engine\n ja_JP: vlm-sglang-engine\n pt_BR: vlm-sglang-engine\n zh_Hans: vlm-sglang-engine\n value: vlm-sglang-engine\n - icon: ''\n label:\n en_US: vlm-sglang-client\n ja_JP: vlm-sglang-client\n pt_BR: vlm-sglang-client\n zh_Hans: vlm-sglang-client\n value: vlm-sglang-client\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: ''\n form: form\n human_description:\n en_US: '(For local deployment v2 
when backend is vlm-sglang-client) Example:\n http:\/\/127.0.0.1:8000, default is empty'\n ja_JP: \uff08\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8v2\u7528 \u89e3\u6790\u5f8c\u7aef\u304cvlm-sglang-client\u306e\u5834\u5408\uff09\u4f8b\uff1ahttp:\/\/127.0.0.1:8000\u3001\u30c7\u30d5\u30a9\u30eb\u30c8\u306f\u7a7a\n pt_BR: '(For local deployment v2 when backend is vlm-sglang-client) Example:\n http:\/\/127.0.0.1:8000, default is empty'\n zh_Hans: \uff08\u7528\u4e8e\u672c\u5730\u90e8\u7f72v2\u7248\u672c \u89e3\u6790\u540e\u7aef\u4e3avlm-sglang-client\u65f6\uff09\u793a\u4f8b\uff1ahttp:\/\/127.0.0.1:8000\uff0c\u9ed8\u8ba4\u503c\u4e3a\u7a7a\n label:\n en_US: sglang-server url\n ja_JP: sglang-server\u30a2\u30c9\u30ec\u30b9\n pt_BR: sglang-server url\n zh_Hans: sglang-server\u5730\u5740\n llm_description: '(For local deployment v2 when backend is vlm-sglang-client)\n Example: http:\/\/127.0.0.1:8000, default is empty'\n max: null\n min: null\n name: sglang_server_url\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n params:\n backend: ''\n enable_formula: ''\n enable_ocr: ''\n enable_table: ''\n extra_formats: ''\n file: ''\n language: ''\n parse_method: ''\n sglang_server_url: ''\n provider_id: langgenius\/mineru\/mineru\n provider_name: langgenius\/mineru\/mineru\n provider_type: builtin\n selected: false\n title: Parse File\n tool_configurations:\n backend:\n type: constant\n value: pipeline\n enable_formula:\n type: constant\n value: 1\n enable_ocr:\n type: constant\n value: true\n enable_table:\n type: constant\n value: 1\n extra_formats:\n type: mixed\n value: '[]'\n language:\n type: mixed\n value: auto\n parse_method:\n type: constant\n value: auto\n sglang_server_url:\n type: mixed\n value: ''\n tool_description: a tool for parsing text, tables, and images, supporting\n multiple formats such as pdf, pptx, docx, etc. 
supporting multiple languages\n such as English, Chinese, etc.\n tool_label: Parse File\n tool_name: parse-file\n tool_node_version: '2'\n tool_parameters:\n file:\n type: variable\n value:\n - '1756915693835'\n - file\n type: tool\n height: 270\n id: '1758027159239'\n position:\n x: -544.9739996945534\n y: 282\n positionAbsolute:\n x: -544.9739996945534\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n viewport:\n x: 679.9701291615181\n y: -191.49392257836791\n zoom: 0.8239704766223018\n rag_pipeline_variables:\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: paragraph\n label: Parent Mode\n max_length: 48\n options:\n - paragraph\n - full_doc\n placeholder: null\n required: true\n tooltips: 'Parent Mode provides two options: paragraph mode splits text into paragraphs\n as parent chunks for retrieval, while full_doc mode uses the entire document\n as a single parent chunk (text beyond 10,000 tokens will be truncated).'\n type: select\n unit: null\n variable: Parent_Mode\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: \\n\\n\n label: Parent Delimiter\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: A delimiter is the character used to separate text. \\n\\n is recommended\n for splitting the original document into large parent chunks. 
You can also use\n special delimiters defined by yourself.\n type: text-input\n unit: null\n variable: Parent_Delimiter\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: 1024\n label: Maximum Parent Length\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: null\n type: number\n unit: tokens\n variable: Maximum_Parent_Length\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: \\n\n label: Child Delimiter\n max_length: 48\n options: []\n placeholder: null\n required: true\n tooltips: A delimiter is the character used to separate text. \\n is recommended\n for splitting parent chunks into small child chunks. You can also use special\n delimiters defined by yourself.\n type: text-input\n unit: null\n variable: Child_Delimiter\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: 256\n label: Maximum Child Length\n max_length: 48\n options: []\n placeholder: null\n required: true\n tooltips: ''\n type: number\n unit: tokens\n variable: Maximum_Child_Length\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: true\n label: Replace consecutive spaces, newlines and tabs.\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: null\n type: checkbox\n unit: null\n variable: clean_1\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: null\n label: Delete all URLs and email addresses.\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: ''\n type: checkbox\n unit: null\n variable: clean_2\n", @@ -6310,7 +6342,7 @@ "id": "103825d3-7018-43ae-bcf0-f3c001f3eb69", "name": "Contextual 
Enrichment Using LLM" }, -{ + "629cb5b8-490a-48bc-808b-ffc13085cb4f": { "chunk_structure": "hierarchical_model", "description": "This Knowledge Pipeline extracts images and tables from complex PDF documents for downstream processing.", "export_data": "dependencies:\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/jina:0.0.8@d3a6766fbb80890d73fea7ea04803f3e1702c6e6bd621aafb492b86222a193dd\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/parentchild_chunker:0.0.7@ee9c253e7942436b4de0318200af97d98d094262f3c1a56edbe29dcb01fbc158\n- current_identifier: null\n type: marketplace\n value:\n marketplace_plugin_unique_identifier: langgenius\/mineru:0.5.0@ca04f2dceb4107e3adf24839756954b7c5bcb7045d035dbab5821595541c093d\nkind: rag_pipeline\nrag_pipeline:\n description: ''\n icon: 87426868-91d6-4774-a535-5fd4595a77b3\n icon_background: null\n icon_type: image\n icon_url: data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAAKAAAACgCAYAAACLz2ctAAAAAXNSR0IArs4c6QAAAERlWElmTU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAA6ABAAMAAAABAAEAAKACAAQAAAABAAAAoKADAAQAAAABAAAAoAAAAACn7BmJAAARwElEQVR4Ae1dvXPcxhVfLMAP0RR1pL7MGVu8G7sXXdszotNYne1x6kgpktZSiiRNIrtMilgqnNZSb4\/lzm4i5i8w1TvDE+UZyZIlnihKOvIAbN5v7\/aIw93xPvBBHPDezBHYBbC7+O2Pb9++\/YAlMiIPHjwoO65btpQqK6VKVKySsqwV9fQpSliy6IcTubhYxrFTrJJqXe+Mz2+I8KgJoeh3IIRBTW1vt+MoXLWWlgRheo\/uqlmWVSVMa67jVJeXl6sHTx7dGb1HurK9uVnybHtNKXFBWAKEW1XCKvcrhb+tCdi+LBeX2ud80o3AaHipDUGkFErdJXJu2J63vliptAncnXr8MakQ8PH9+2tU9Av0omtCCZx3iZSSsLCE49j6iHPE+U+fCEnnCEOmTp\/uehbXzPWuizmNoFaC4CQdFxCE3V9\/bcd4vk8txpLwW\/f6FPZ9RT8c\/fZ9nSdESmGtK1veOvPGG3SerCRGQGg6V8rLxIwPg6QDUWzb1kTDcXrKaROu16v6T550RMuTJzvCHOhEYBS8PM8TIGmj4QrX9ejndiRG5Kj6lvj8zLlzNzsuxBiInYCaeI7zqeWrK8YuA+lmZqbF9PSUcIh0o2irUQCNEZeJTSoqXg0i4d7evial0ZIgopLWzdNvvvl53MDESsBfNrc+sqX6wth0juOIublZMUXHcSUqoOPmO6nPxYkXiFinn9GMIGLcGjEWApLWK7u2\/ZVpauMgniFAnICaNPN8TAIvaMXd3ZcHdqMlbjve1NXFSvSetIxaGU\/u3\/\/Uk\/aPIB+a1rm5Y
+LEwnwkrRe1TPx8vAigBVssLYj51+Z0x5Dq+iNXNn58tLV1OWpOYxMQtt7jra0vqFd1HbYe7DsU8tjsTNQy8fMZRQB2PJQLjiQlS4mvwIEoxR2rCdZNrpTfUnd9FVrv2LHZxIiXRJMSBbCsP5sWXvX6nnj1qq5dPOQQ33D86Y\/HaZJH1oAgnyflHZAPfrrSieOJkS\/rlV3k8s1SS3eC6h4cABc82bizvfmgPComIxHQkA+9XPjwoI6bBRg1W74\/Dwig7sEBuNbIDCPFNDoJhyYgky8PlIn\/HUDChQgkHIqAvcg3ijM5\/tfmFLOEALgwLgmHIiANqX0bbHaZfFmq\/myUJUxCV+5\/S4qrNKh0AwnY7GY3OxwLx18baRhtUOZ8PV8IgITHiSOmY0KDE9cGveGhBHy0SY5GJa4gYe5wDIKSrwMB0zHBDCZw5+G9e1cOQ6YvAWH3kX2pnYzw8zVZfVhSfI0RaCIAroAzEJp6cu0w90xfApL6pEkFogSvN49uNIHlv8MjAD8hRsdISq7d+Krfkz0J2Gp6PwKT51pM7pcAxzMC\/RDQY8fNpnjtV5op1eu+ngSUUmnjEeTjprcXbBw3DALoO5imWJA516tX3EVAmt1yDS4XEK816DxMXnwPI9ATATTFmJ5H5lx5X8quDkkXAZXvX0ZK8\/NzPRPkSEZgVAQwKRlCq34+DWvBDgLC9oP2w\/yvKLOYdW78hxFoIQAuQQuSNNcJBZDpIKCx\/bjpDSDEp7EgYLQgjWR8GEywTcBHmz\/r9bls+wXh4fO4EIAWbDmn1x5v3l8z6bYJKKV3GZFTtEyShRFIAoHp5kxq4Ut\/zaTfJqAS8gIiufk10PAxbgRajmloQs01pK+n5KNn4kp7GxEnlwZOYMBtqUl4inlqGeckoywt5MfODbXajp7G7\/jeIrYB0RoQe7UAb+755oR1GX0NOKYlzZ6GGM5pAhIzVxFp074sLIxAkghg7x8I7VezhmPTBrSs8wiwBgQKLEkigLVEEIyM4Njs8iqLAtQNsdt9ElzLhGTJhskEIBNeCGxG9YLegaZpaaXXYlyzCcbqJhZGIEkEYAdCjAaUD2jiKSJ41gtQYEkaAd0RoYkuEOyKK2mMroyA3YrEOQsjkCQCRgs6dbcsaYtc7fizZFM1Jpkxp80IAAHTE7ZsVZbkgikjkptgoMCSBgJGAxL3SmiMmxqwZRymUQDOo9gIGAKCe9L0RgKRxUaH3z5xBExrS5xbaTv+9FSZxLPmDBiBTgSId9YKorLohO4sKofygoBRdp5Si20NmJeX4\/fIPgLG40JEPMEEzH595bqEtF7Ool4wLUWa0F7wr+\/\/JlMVdOrOfzrKY8p3\/C9\/FjMXL3ZcK2rADHrQHtPkiBa+dsOYdrmooCT93s\/\/8U+x9\/33SWczcelzE5xilYGEjY2NFHPMflZMwJTraOdvfxfuTz+lnGt2s3O8bb0URPheA+NxsZeU5\/N1Qqp2d8Wzq38SJ774l3DefrvzYgZDSazJ0V\/r3Hmu3xZTEHgoLuWKNyT0Hj5MOedsZBfo8OqhOCbgEdQLSLhDmrCIJOwg4BFgz1m2EAD5ikpCQwIHX9SGyJjWAydhM5jC5vFoSLhANqH9+uuZf8W4bHppNZd\/xN\/ryDyE2SugIWERm2MmYEb4aEgI27BIwgTMUG2DhDXqmBSJhEzADBEQRfHISV0kEjIBM0ZAQ0KMmBRBmIAZrWWMGWPsOO\/CBMxwDWP2TN5JyATMMAFRNJBw98t\/Z7yU4xePCTg+dqk9Wf\/6a\/Hy1q3U8kszIyZgmmhHyOvlzVu5JCETMAIp0n40jyRkAqbNooj55Y2ETMCIhDiKx0HCV19\/cxRZx54nEzB2SNNJ8MWXX+ZikRMTMB2+JJJLHnyE\/FmkRKhxkGh4nfDBFT4DAqwBmQdHigAT8Ejh58yZgMyBI0WAbcCY4Td7wcScbN\/kJt3GZA3Yt2r5
QhoIMAHTQJnz6IsAE7AvNHwhDQSYgGmgzHn0RYAJ2BcavpAGAkzANFDmPPoiwATsCw1fSAOBifcDTrofLI1KznIerAGzXDsFKBsTsACVnOVXZAJmuXYKUDYmYAEqOcuvyATMcu0UoGxMwAJUcpZfkQmY5dopQNkmzg846nw7m77Fge9xzH7wgZhaPT+wSodN35qf1+kibef8eTHz3rsD0+51w7D59Xq2V9yk+UUnjoC9QD8sDhs+4odNfqZWV8U8fTQwjs3AsYsptlDTn96ivVt2iZDT770n5i79Lpb0D3unPF0rVBMMstT+8MdEPpUFQoLkSD8vi8bTIHqhCAhAQRR8KiupHemRPhaN53lLtTiJOfFN8CCbp7FxV9RJM+398EMbN5Bkl3YfxffaBkm\/9P2Hv2gSI2337t0uQmNLNeSD7wSPIv3yGyWNSbp34gk4CGx0PPCD3RfcY8\/Yb7ALxxH5+lmBn+nY7H3\/g04\/qFnRJDtvvSWO\/faTcbIoxDOFaYLnLl\/SnZBgrYI0ccnMxQ9Er68doTnmz7P2R7kwBAQE6KEGpUFNZ5wCLdubhPndYjcqfoUiYPj7vMHmMiqQ5nmQEK6eoKC5hz3I0o1AoQgI53EaArsybFvWY2zu03iHtPIoFAHRIw5KWCMGr0U9n363c2QEznCWbgQKRcB6wBUDKOTZs92IxBRjescmubjtTZPupB9z74YxFQQXDNwiQZm9eDEYjPU8PNznD2kDjjo2POl+w1wTEIa\/+9P\/tH9Oj9kGKAaCTI85gSCQTN\/TsL3JnZDeUE08AUfVGIAB5IC7hOXoESiUDQi4QT4MwYWbyLirIqzxwhox7vwmNb2J14CjAB\/ndKxB+aLpD8qwhJ90my74zsOc556Akmy9GXKJYK5euGc6DEDj3hMefkuyxz1uGbPw3MQTMKsao\/5N54dkZugfgKUbgcLZgN0QxB+DSQ7hYT5niOUA8Zck+yk6\/vZTXUpfedkv7QSUEMQLTvtCkWdoPcqwNmDWX9F\/8iSWIvq1Zzod1oCxwNlMBOTb6THbGlPBWHoj4FhC1JQQJaWUsCwKsYyFwCuy+fARwbD7Ze7Spdxov7GA6fEQuNaSmkOnNQowAQ0kQx4xJb9BEwwwHR\/T8sPEQzJoeln7dQPaQUB7cVGQ7hOytCCk5BY5DNc4Iy2GfMf\/+pdwchMXlidPxl9m3xfSniLWCTHxbpj40YmWIkY80OzyOpDhcGQCDofTwLtAvGOffKKJx8NuA+Fq38AEbEMx2glIBtfKFG3LgVEW5+239DjzaKkU826\/1QlRQtWsx1tbd8gIXFtYmBdTDvOxmJRI960brit2dmiNjCXWudeRLvacWwgBEBBuGKH8tm8mdAsHGYHkEJDkk9FjIgHfTHK5ccqMACHgeb7GgdwwVW6CmRLpI3AwEiIkWIgSeOQcZGEE0kCg3QtW6t6BDRhgZRqF4DyKi0DA3KtJy7eanRAmYHEZkfKb+8YGtKyqVI5VRf6uy\/MBU66HwmbXboI9qyZd160CiYBaLCww\/OLpIOC3+hvurFOVy5VKFdkikn2B6VRA0XMxBFxeXm66YSyhqgCFxuaKjg2\/f8IIuJ4x9dQGstKDv8qyaAM7UW40XDEzM51wEUZLPq41CKPlmp+7E5nPFwEe0wEhp989JKMd0Rb5YxA4YCdCLIxA\/AhgIgKEiKc1YHMkxLLWEelxTxgwsCSIgPG20PqjAwLanreOPKEBuSOSIPqcNLn7mhrQcE7bgIuVSo3mBa6TK2bN9T0xJbM7LzBrNk3WOJVlm9k0v9Td3QDngF2zCcaZUv\/FYX+\/gQMLIxA7Anv1fZ0m+Vo01xA4IKAv1xGxt9e8CecsjECcCLQ1oO\/fNOm2CXi68uY6pkhjRKR9o7mLj4xARASg2PRgB82+OlOp6A4IkmwTUKev1Hc4vnpZ10H+wwjEhUDdtKyW+DyYZgcBnaZqrEEDshYMwsTnU
RAAl9D7JduveubcuZvBtDoI2OyZqBu4gbVgECY+j4LA7u5L\/Ti5+G6F0+kgIC6SFrxOY8JVsLZe3wvfz2FGYCQEgrbf2crKZ+GHuwgILSh96ypufPmqzo7pMGIcHhoBLPMAh7SEbD+TSBcBceFU5dxt0yPefdFUn+YBPjICwyIAM05PvbLE7bDtZ9LoSUBcpGG539Ohtt9ocFNs0OLj0AjAfNvb1z7lmutN6Ra118N9CagnqvpKd5mhRnnVXC\/4OK4XAsGmV1ni6nJludrrPsT1JSAunq6sXKfJqjfgnMZeHkxCoMJyGALgCLgCzlCv90a\/ptekcSgBcZPt+59h8Bht+fPnL7hTYpDjYxcCIB040hzxUBtnKitXum4KRQwkIHrFru9\/DNeMR9O1nj0ndvM+MiEYOQjyPUMriSl95HD2\/OmPh0FlIAGRCOxBUq3vMwmHgbR493STb+r9w+y+IEJDERAP9CIh24RBKIt5Dg50ar7hyQfEhiYgbg6TkDsmQKW4YjocB83uaOQDciMREA8YEpqOybNnz9lPCGAKJvDzoe5Nh8PzRycfIBuZgHgIJDy9svKOcdG8ePlKYMCZm2Sgk28xPV3UOc7hanlB\/YNhbb4wOmMR0CRyamXlivKFHjGB1xtNMs+oNujk7witt13bERgdI6kJX12Fq6XSWt8xzhtHIiAyPFM5d5MWMr1DY8e3oY4xdoxC8nzCcaojm8+gLqFcjNbDPAHXn3oHAxVRS2xFTSD4\/KPNrctCqmuWsMqIx6772Gkhym4L4VVevCoOyPaXOPEC8TChwCgT+Peoxbt6FpNVYpJYCWjK9Hjz3mdKikuGiPgEmCbj7PTIn4KIE1BTvjwfo+AFmw5rw7EyEqYUwi1Bc3tjV\/jXozS3JrHgMRECmgzCGtHEg4y2Y2sySlsKx7bNpa5jFEC7EitAxLB46Q4EEWyf9gOCGwW7YuiNCQ5Ip7\/jQSz8bpeWasRNPFMViRLQZPJo8+dV2vjjsiXFBXorOu8WaEmbfvhkLEipj3SOD2oj3oh96hRtbN1ZbNyLX5HEECj8zo3Hj3UUrmMjSLl0sukqoXPEYWsMfY3s9Z5C9p3wsEZcruuVkj1vii8y9Vrb3NwsHRf2mpJqlVhzntAo9yMlXtN80d28slxcMqd87IHAKHhhWz7sjKY8bBZurT8X3npSmq5HUXVU6gTsV5AHmw\/KjnDLBEqJyFmm+0oEzop6+pQ6XQJhLdbiYonCJRPGkT43i3BHXPB6Ts9rhFUt\/G7+9nYVcWS94VrNWloSrd3PatgPnLCqusKpjuu3Q9pxyv8BVb3XBNS3Vn0AAAAASUVORK5CYII=\n name: Complex PDF with Images & Tables\nversion: 0.1.0\nworkflow:\n conversation_variables: []\n environment_variables: []\n features: {}\n graph:\n edges:\n - data:\n isInLoop: false\n sourceType: datasource\n targetType: tool\n id: 1750400203722-source-1751281136356-target\n selected: false\n source: '1750400203722'\n sourceHandle: source\n target: '1751281136356'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInLoop: false\n sourceType: tool\n targetType: knowledge-index\n id: 1751338398711-source-1750400198569-target\n selected: false\n source: '1751338398711'\n sourceHandle: source\n target: '1750400198569'\n 
targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInLoop: false\n sourceType: tool\n targetType: tool\n id: 1751281136356-source-1751338398711-target\n selected: false\n source: '1751281136356'\n sourceHandle: source\n target: '1751338398711'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n chunk_structure: hierarchical_model\n embedding_model: jina-embeddings-v2-base-en\n embedding_model_provider: langgenius\/jina\/jina\n index_chunk_variable_selector:\n - '1751338398711'\n - result\n indexing_technique: high_quality\n keyword_number: 10\n retrieval_model:\n reranking_enable: true\n reranking_mode: reranking_model\n reranking_model:\n reranking_model_name: jina-reranker-v1-base-en\n reranking_provider_name: langgenius\/jina\/jina\n score_threshold: 0\n score_threshold_enabled: false\n search_method: hybrid_search\n top_k: 3\n weights: null\n selected: true\n title: Knowledge Base\n type: knowledge-index\n height: 114\n id: '1750400198569'\n position:\n x: 355.92518399555183\n y: 282\n positionAbsolute:\n x: 355.92518399555183\n y: 282\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n datasource_configurations: {}\n datasource_label: File\n datasource_name: upload-file\n datasource_parameters: {}\n fileExtensions:\n - txt\n - markdown\n - mdx\n - pdf\n - html\n - xlsx\n - xls\n - vtt\n - properties\n - doc\n - docx\n - csv\n - eml\n - msg\n - pptx\n - xml\n - epub\n - ppt\n - md\n plugin_id: langgenius\/file\n provider_name: file\n provider_type: local_file\n selected: false\n title: File Upload\n type: datasource\n height: 52\n id: '1750400203722'\n position:\n x: -579\n y: 282\n positionAbsolute:\n x: -579\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n author: TenTen\n desc: ''\n height: 337\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Currently\n we support 4 types of \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Data\n Sources\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\":\n File Upload, Online Drive, Online Doc, and Web Crawler. Different types\n of Data Sources have different input and output types. The output of File\n Upload and Online Drive are files, while the output of Online Doc and WebCrawler\n are pages. You can find more Data Sources on our Marketplace.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"A\n Knowledge Pipeline can have multiple data sources. Each data source can\n be selected more than once with different settings. Each added data source\n is a tab on the add file interface. 
However, each time the user can only\n select one data source to import the file and trigger its subsequent processing.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 358\n height: 337\n id: '1751264451381'\n position:\n x: -990.8091030156684\n y: 282\n positionAbsolute:\n x: -990.8091030156684\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 358\n - data:\n author: TenTen\n desc: ''\n height: 260\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"A\n \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Knowledge\n Pipeline\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n starts with Data Source as the starting node and ends with the knowledge\n base node. 
The general steps are: import documents from the data source\n \u2192 use extractor to extract document content \u2192 split and clean content into\n structured chunks \u2192 store in the knowledge base.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n user input variables required by the Knowledge Pipeline node must be predefined\n and managed via the Input Field section located in the top-right corner\n of the orchestration canvas. It determines what input fields the end users\n will see and need to fill in when importing files to the knowledge base\n through this pipeline.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Unique\n Inputs: Input fields defined here are only available to the selected data\n source and its downstream nodes.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Global\n Inputs: These input fields are shared across all subsequent nodes after\n the data source and are typically set during the Process Documents 
step.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"For\n more information, see \",\"type\":\"text\",\"version\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https:\/\/docs.dify.ai\/en\/guides\/knowledge-base\/knowledge-pipeline\/knowledge-pipeline-orchestration.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"link\",\"version\":1,\"rel\":\"noreferrer\",\"target\":null,\"title\":null,\"url\":\"https:\/\/docs.dify.ai\/en\/guides\/knowledge-base\/knowledge-pipeline\/knowledge-pipeline-orchestration\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 1182\n height: 260\n id: '1751266376760'\n position:\n x: -579\n y: -22.64803881585007\n positionAbsolute:\n x: -579\n y: -22.64803881585007\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 1182\n - data:\n author: TenTen\n desc: ''\n height: 541\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"A\n document extractor for large language models (LLMs) like MinerU is a tool\n that preprocesses and converts diverse document types into structured, clean,\n and machine-readable data. 
This structured data can then be used to train\n or augment LLMs and retrieval-augmented generation (RAG) systems by providing\n them with accurate, well-organized content from varied sources. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"MinerU\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n is an advanced open-source document extractor designed specifically to convert\n complex, unstructured documents\u2014such as PDFs, Word files, and PPTs\u2014into\n high-quality, machine-readable formats like Markdown and JSON. MinerU addresses\n challenges in document parsing such as layout detection, formula recognition,\n and multi-language support, which are critical for generating high-quality\n training corpora for LLMs.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 541\n id: '1751266402561'\n position:\n x: -263.7680017647218\n y: 558.328085421591\n positionAbsolute:\n x: -263.7680017647218\n y: 558.328085421591\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n author: TenTen\n desc: ''\n height: 554\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Parent-Child\n 
Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\n addresses the dilemma of context and precision by leveraging a two-tier\n hierarchical approach that effectively balances the trade-off between accurate\n matching and comprehensive contextual information in RAG systems. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Here\n is the essential mechanism of this structured, two-level information access:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"-\n Query Matching with Child Chunks: Small, focused pieces of information,\n often as concise as a single sentence within a paragraph, are used to match\n the user''s query. These child chunks enable precise and relevant initial\n retrieval.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"-\n Contextual Enrichment with Parent Chunks: Larger, encompassing sections\u2014such\n as a paragraph, a section, or even an entire document\u2014that include the matched\n child chunks are then retrieved. 
These parent chunks provide comprehensive\n context for the Language Model (LLM).\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 554\n id: '1751266447821'\n position:\n x: 42.95253988413964\n y: 366.1915342509804\n positionAbsolute:\n x: 42.95253988413964\n y: 366.1915342509804\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n author: TenTen\n desc: ''\n height: 411\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n knowledge base provides two indexing methods:\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"High-Quality\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0and\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Economical\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\",\n each with different retrieval strategies. High-Quality mode uses embeddings\n for vectorization and supports vector, full-text, and hybrid retrieval,\n offering more accurate results but higher resource usage. 
Economical mode\n uses keyword-based inverted indexing with no token consumption but lower\n accuracy; upgrading to High-Quality is possible, but downgrading requires\n creating a new knowledge base.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0,\"textStyle\":\"\"},{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"*\n Parent-Child Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0and\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Q&A\n Mode\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0only\n support the\u00a0\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"High-Quality\",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"\u00a0indexing\n method.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"start\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1,\"textStyle\":\"\"}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1,\"textFormat\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 240\n height: 411\n id: '1751266580099'\n position:\n x: 355.92518399555183\n y: 434.6494699299023\n positionAbsolute:\n x: 355.92518399555183\n y: 434.6494699299023\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 240\n - data:\n credential_id: fd1cbc33-1481-47ee-9af2-954b53d350e0\n is_team_authorization: false\n output_schema:\n properties:\n full_zip_url:\n description: The zip URL of 
the complete parsed result\n type: string\n images:\n description: The images extracted from the file\n items:\n type: object\n type: array\n type: object\n paramSchemas:\n - auto_generate: null\n default: null\n form: llm\n human_description:\n en_US: the file to be parsed(support pdf, ppt, pptx, doc, docx, png, jpg,\n jpeg)\n ja_JP: \u89e3\u6790\u3059\u308b\u30d5\u30a1\u30a4\u30eb(pdf\u3001ppt\u3001pptx\u3001doc\u3001docx\u3001png\u3001jpg\u3001jpeg\u3092\u30b5\u30dd\u30fc\u30c8)\n pt_BR: the file to be parsed(support pdf, ppt, pptx, doc, docx, png, jpg,\n jpeg)\n zh_Hans: \u7528\u4e8e\u89e3\u6790\u7684\u6587\u4ef6(\u652f\u6301 pdf, ppt, pptx, doc, docx, png, jpg, jpeg)\n label:\n en_US: file\n ja_JP: file\n pt_BR: file\n zh_Hans: file\n llm_description: the file to be parsed (support pdf, ppt, pptx, doc, docx,\n png, jpg, jpeg)\n max: null\n min: null\n name: file\n options: []\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: file\n - auto_generate: null\n default: auto\n form: form\n human_description:\n en_US: (For local deployment service)Parsing method, can be auto, ocr,\n or txt. Default is auto. If results are not satisfactory, try ocr\n ja_JP: \uff08\u30ed\u30fc\u30ab\u30eb\u30c7\u30d7\u30ed\u30a4\u30e1\u30f3\u30c8\u30b5\u30fc\u30d3\u30b9\u7528\uff09\u89e3\u6790\u65b9\u6cd5\u306f\u3001auto\u3001ocr\u3001\u307e\u305f\u306ftxt\u306e\u3044\u305a\u308c\u304b\u3067\u3059\u3002\u30c7\u30d5\u30a9\u30eb\u30c8\u306fauto\u3067\u3059\u3002\u7d50\u679c\u304c\u6e80\u8db3\u3067\u304d\u306a\u3044\u5834\u5408\u306f\u3001ocr\u3092\u8a66\u3057\u3066\u304f\u3060\u3055\u3044\n pt_BR: (For local deployment service)Parsing method, can be auto, ocr,\n or txt. Default is auto. 
If results are not satisfactory, try ocr\n zh_Hans: \uff08\u7528\u4e8e\u672c\u5730\u90e8\u7f72\u670d\u52a1\uff09\u89e3\u6790\u65b9\u6cd5\uff0c\u53ef\u4ee5\u662fauto, ocr, \u6216 txt\u3002\u9ed8\u8ba4\u662fauto\u3002\u5982\u679c\u7ed3\u679c\u4e0d\u7406\u60f3\uff0c\u8bf7\u5c1d\u8bd5ocr\n label:\n en_US: parse method\n ja_JP: \u89e3\u6790\u65b9\u6cd5\n pt_BR: parse method\n zh_Hans: \u89e3\u6790\u65b9\u6cd5\n llm_description: Parsing method, can be auto, ocr, or txt. Default is auto.\n If results are not satisfactory, try ocr\n max: null\n min: null\n name: parse_method\n options:\n - label:\n en_US: auto\n ja_JP: auto\n pt_BR: auto\n zh_Hans: auto\n value: auto\n - label:\n en_US: ocr\n ja_JP: ocr\n pt_BR: ocr\n zh_Hans: ocr\n value: ocr\n - label:\n en_US: txt\n ja_JP: txt\n pt_BR: txt\n zh_Hans: txt\n value: txt\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: 1\n form: form\n human_description:\n en_US: (For official API) Whether to enable formula recognition\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09\u6570\u5f0f\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API) Whether to enable formula recognition\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u662f\u5426\u5f00\u542f\u516c\u5f0f\u8bc6\u522b\n label:\n en_US: Enable formula recognition\n ja_JP: \u6570\u5f0f\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable formula recognition\n zh_Hans: \u5f00\u542f\u516c\u5f0f\u8bc6\u522b\n llm_description: (For official API) Whether to enable formula recognition\n max: null\n min: null\n name: enable_formula\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: 1\n form: form\n human_description:\n en_US: (For official API) Whether to enable table recognition\n ja_JP: 
\uff08\u516c\u5f0fAPI\u7528\uff09\u8868\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API) Whether to enable table recognition\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u662f\u5426\u5f00\u542f\u8868\u683c\u8bc6\u522b\n label:\n en_US: Enable table recognition\n ja_JP: \u8868\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable table recognition\n zh_Hans: \u5f00\u542f\u8868\u683c\u8bc6\u522b\n llm_description: (For official API) Whether to enable table recognition\n max: null\n min: null\n name: enable_table\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: doclayout_yolo\n form: form\n human_description:\n en_US: '(For official API) Optional values: doclayout_yolo, layoutlmv3,\n default value is doclayout_yolo. doclayout_yolo is a self-developed\n model with better effect'\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09\u30aa\u30d7\u30b7\u30e7\u30f3\u5024\uff1adoclayout_yolo\u3001layoutlmv3\u3001\u30c7\u30d5\u30a9\u30eb\u30c8\u5024\u306f doclayout_yolo\u3002doclayout_yolo\n \u306f\u81ea\u5df1\u958b\u767a\u30e2\u30c7\u30eb\u3067\u3001\u52b9\u679c\u304c\u3088\u308a\u826f\u3044\n pt_BR: '(For official API) Optional values: doclayout_yolo, layoutlmv3,\n default value is doclayout_yolo. doclayout_yolo is a self-developed\n model with better effect'\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u53ef\u9009\u503c\uff1adoclayout_yolo\u3001layoutlmv3\uff0c\u9ed8\u8ba4\u503c\u4e3a doclayout_yolo\u3002doclayout_yolo\n \u4e3a\u81ea\u7814\u6a21\u578b\uff0c\u6548\u679c\u66f4\u597d\n label:\n en_US: Layout model\n ja_JP: \u30ec\u30a4\u30a2\u30a6\u30c8\u691c\u51fa\u30e2\u30c7\u30eb\n pt_BR: Layout model\n zh_Hans: \u5e03\u5c40\u68c0\u6d4b\u6a21\u578b\n llm_description: '(For official API) Optional values: doclayout_yolo, layoutlmv3,\n default value is doclayout_yolo. 
doclayout_yolo is a self-developed model\n withbetter effect'\n max: null\n min: null\n name: layout_model\n options:\n - label:\n en_US: doclayout_yolo\n ja_JP: doclayout_yolo\n pt_BR: doclayout_yolo\n zh_Hans: doclayout_yolo\n value: doclayout_yolo\n - label:\n en_US: layoutlmv3\n ja_JP: layoutlmv3\n pt_BR: layoutlmv3\n zh_Hans: layoutlmv3\n value: layoutlmv3\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: auto\n form: form\n human_description:\n en_US: '(For official API) Specify document language, default ch, can\n be set to auto, when auto, the model will automatically identify document\n language, other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/ppocr\/blog\/multi_languages.html#5'\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09\u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\u3092\u6307\u5b9a\u3057\u307e\u3059\u3002\u30c7\u30d5\u30a9\u30eb\u30c8\u306fch\u3067\u3001auto\u306b\u8a2d\u5b9a\u3067\u304d\u307e\u3059\u3002auto\u306e\u5834\u5408\u3001\u30e2\u30c7\u30eb\u306f\u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\u3092\u81ea\u52d5\u7684\u306b\u8b58\u5225\u3057\u307e\u3059\u3002\u4ed6\u306e\u30aa\u30d7\u30b7\u30e7\u30f3\u5024\u30ea\u30b9\u30c8\u306b\u3064\u3044\u3066\u306f\u3001\u6b21\u3092\u53c2\u7167\u3057\u3066\u304f\u3060\u3055\u3044\uff1ahttps:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/ppocr\/blog\/multi_languages.html#5\n pt_BR: '(For official API) Specify document language, default ch, can\n be set to auto, when auto, the model will automatically identify document\n language, other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/ppocr\/blog\/multi_languages.html#5'\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u6307\u5b9a\u6587\u6863\u8bed\u8a00\uff0c\u9ed8\u8ba4 
ch\uff0c\u53ef\u4ee5\u8bbe\u7f6e\u4e3aauto\uff0c\u5f53\u4e3aauto\u65f6\u6a21\u578b\u4f1a\u81ea\u52a8\u8bc6\u522b\u6587\u6863\u8bed\u8a00\uff0c\u5176\u4ed6\u53ef\u9009\u503c\u5217\u8868\u8be6\u89c1\uff1ahttps:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/ppocr\/blog\/multi_languages.html#5\n label:\n en_US: Document language\n ja_JP: \u30c9\u30ad\u30e5\u30e1\u30f3\u30c8\u8a00\u8a9e\n pt_BR: Document language\n zh_Hans: \u6587\u6863\u8bed\u8a00\n llm_description: '(For official API) Specify document language, default\n ch, can be set to auto, when auto, the model will automatically identify\n document language, other optional value list see: https:\/\/paddlepaddle.github.io\/PaddleOCR\/latest\/ppocr\/blog\/multi_languages.html#5'\n max: null\n min: null\n name: language\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 0\n form: form\n human_description:\n en_US: (For official API) Whether to enable OCR recognition\n ja_JP: \uff08\u516c\u5f0fAPI\u7528\uff09OCR\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\u304b\u3069\u3046\u304b\n pt_BR: (For official API) Whether to enable OCR recognition\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u662f\u5426\u5f00\u542fOCR\u8bc6\u522b\n label:\n en_US: Enable OCR recognition\n ja_JP: OCR\u8a8d\u8b58\u3092\u6709\u52b9\u306b\u3059\u308b\n pt_BR: Enable OCR recognition\n zh_Hans: \u5f00\u542fOCR\u8bc6\u522b\n llm_description: (For official API) Whether to enable OCR recognition\n max: null\n min: null\n name: enable_ocr\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: '[]'\n form: form\n human_description:\n en_US: '(For official API) Example: [\"docx\",\"html\"], markdown, json are\n the default export formats, no need to set, this parameter only supports\n one or more of docx, html, latex'\n ja_JP: 
\uff08\u516c\u5f0fAPI\u7528\uff09\u4f8b\uff1a[\"docx\",\"html\"]\u3001markdown\u3001json\u306f\u30c7\u30d5\u30a9\u30eb\u30c8\u306e\u30a8\u30af\u30b9\u30dd\u30fc\u30c8\u5f62\u5f0f\u3067\u3042\u308a\u3001\u8a2d\u5b9a\u3059\u308b\u5fc5\u8981\u306f\u3042\u308a\u307e\u305b\u3093\u3002\u3053\u306e\u30d1\u30e9\u30e1\u30fc\u30bf\u306f\u3001docx\u3001html\u3001latex\u306e3\u3064\u306e\u5f62\u5f0f\u306e\u3044\u305a\u308c\u304b\u307e\u305f\u306f\u8907\u6570\u306e\u307f\u3092\u30b5\u30dd\u30fc\u30c8\u3057\u307e\u3059\n pt_BR: '(For official API) Example: [\"docx\",\"html\"], markdown, json are\n the default export formats, no need to set, this parameter only supports\n one or more of docx, html, latex'\n zh_Hans: \uff08\u7528\u4e8e\u5b98\u65b9API\uff09\u793a\u4f8b\uff1a[\"docx\",\"html\"],markdown\u3001json\u4e3a\u9ed8\u8ba4\u5bfc\u51fa\u683c\u5f0f\uff0c\u65e0\u987b\u8bbe\u7f6e\uff0c\u8be5\u53c2\u6570\u4ec5\u652f\u6301docx\u3001html\u3001latex\u4e09\u79cd\u683c\u5f0f\u4e2d\u7684\u4e00\u4e2a\u6216\u591a\u4e2a\n label:\n en_US: Extra export formats\n ja_JP: \u8ffd\u52a0\u306e\u30a8\u30af\u30b9\u30dd\u30fc\u30c8\u5f62\u5f0f\n pt_BR: Extra export formats\n zh_Hans: \u989d\u5916\u5bfc\u51fa\u683c\u5f0f\n llm_description: '(For official API) Example: [\"docx\",\"html\"], markdown,\n json are the default export formats, no need to set, this parameter only\n supports one or more of docx, html, latex'\n max: null\n min: null\n name: extra_formats\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n params:\n enable_formula: ''\n enable_ocr: ''\n enable_table: ''\n extra_formats: ''\n file: ''\n language: ''\n layout_model: ''\n parse_method: ''\n provider_id: langgenius\/mineru\/mineru\n provider_name: langgenius\/mineru\/mineru\n provider_type: builtin\n selected: false\n title: MinerU\n tool_configurations:\n enable_formula:\n type: constant\n value: 1\n enable_ocr:\n type: constant\n value: 0\n enable_table:\n type: 
constant\n value: 1\n extra_formats:\n type: constant\n value: '[]'\n language:\n type: constant\n value: auto\n layout_model:\n type: constant\n value: doclayout_yolo\n parse_method:\n type: constant\n value: auto\n tool_description: a tool for parsing text, tables, and images, supporting\n multiple formats such as pdf, pptx, docx, etc. supporting multiple languages\n such as English, Chinese, etc.\n tool_label: Parse File\n tool_name: parse-file\n tool_node_version: '2'\n tool_parameters:\n file:\n type: variable\n value:\n - '1750400203722'\n - file\n type: tool\n height: 244\n id: '1751281136356'\n position:\n x: -263.7680017647218\n y: 282\n positionAbsolute:\n x: -263.7680017647218\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n - data:\n is_team_authorization: true\n output_schema:\n properties:\n result:\n description: Parent child chunks result\n items:\n type: object\n type: array\n type: object\n paramSchemas:\n - auto_generate: null\n default: null\n form: llm\n human_description:\n en_US: ''\n ja_JP: ''\n pt_BR: ''\n zh_Hans: ''\n label:\n en_US: Input Content\n ja_JP: Input Content\n pt_BR: Conte\u00fado de Entrada\n zh_Hans: \u8f93\u5165\u6587\u672c\n llm_description: The text you want to chunk.\n max: null\n min: null\n name: input_text\n options: []\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: paragraph\n form: llm\n human_description:\n en_US: Split text into paragraphs based on separator and maximum chunk\n length, using split text as parent block or entire document as parent\n block and directly retrieve.\n ja_JP: Split text into paragraphs based on separator and maximum chunk\n length, using split text as parent block or entire document as parent\n block and directly retrieve.\n pt_BR: Dividir texto em par\u00e1grafos com base no separador e no comprimento\n m\u00e1ximo do bloco, usando o texto 
dividido como bloco pai ou documento\n completo como bloco pai e diretamente recuper\u00e1-lo.\n zh_Hans: \u6839\u636e\u5206\u9694\u7b26\u548c\u6700\u5927\u5757\u957f\u5ea6\u5c06\u6587\u672c\u62c6\u5206\u4e3a\u6bb5\u843d\uff0c\u4f7f\u7528\u62c6\u5206\u6587\u672c\u4f5c\u4e3a\u68c0\u7d22\u7684\u7236\u5757\u6216\u6574\u4e2a\u6587\u6863\u7528\u4f5c\u7236\u5757\u5e76\u76f4\u63a5\u68c0\u7d22\u3002\n label:\n en_US: Parent Mode\n ja_JP: Parent Mode\n pt_BR: Modo Pai\n zh_Hans: \u7236\u5757\u6a21\u5f0f\n llm_description: Split text into paragraphs based on separator and maximum\n chunk length, using split text as parent block or entire document as parent\n block and directly retrieve.\n max: null\n min: null\n name: parent_mode\n options:\n - label:\n en_US: Paragraph\n ja_JP: Paragraph\n pt_BR: Par\u00e1grafo\n zh_Hans: \u6bb5\u843d\n value: paragraph\n - label:\n en_US: Full Document\n ja_JP: Full Document\n pt_BR: Documento Completo\n zh_Hans: \u5168\u6587\n value: full_doc\n placeholder: null\n precision: null\n required: true\n scope: null\n template: null\n type: select\n - auto_generate: null\n default: '\n\n\n '\n form: llm\n human_description:\n en_US: Separator used for chunking\n ja_JP: Separator used for chunking\n pt_BR: Separador usado para divis\u00e3o\n zh_Hans: \u7528\u4e8e\u5206\u5757\u7684\u5206\u9694\u7b26\n label:\n en_US: Parent Delimiter\n ja_JP: Parent Delimiter\n pt_BR: Separador de Pai\n zh_Hans: \u7236\u5757\u5206\u9694\u7b26\n llm_description: The separator used to split chunks\n max: null\n min: null\n name: separator\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 1024\n form: llm\n human_description:\n en_US: Maximum length for chunking\n ja_JP: Maximum length for chunking\n pt_BR: Comprimento m\u00e1ximo para divis\u00e3o\n zh_Hans: \u7528\u4e8e\u5206\u5757\u7684\u6700\u5927\u957f\u5ea6\n label:\n en_US: Maximum Parent Chunk Length\n ja_JP: 
Maximum Parent Chunk Length\n pt_BR: Comprimento M\u00e1ximo do Bloco Pai\n zh_Hans: \u6700\u5927\u7236\u5757\u957f\u5ea6\n llm_description: Maximum length allowed per chunk\n max: null\n min: null\n name: max_length\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: number\n - auto_generate: null\n default: '. '\n form: llm\n human_description:\n en_US: Separator used for subchunking\n ja_JP: Separator used for subchunking\n pt_BR: Separador usado para subdivis\u00e3o\n zh_Hans: \u7528\u4e8e\u5b50\u5206\u5757\u7684\u5206\u9694\u7b26\n label:\n en_US: Child Delimiter\n ja_JP: Child Delimiter\n pt_BR: Separador de Subdivis\u00e3o\n zh_Hans: \u5b50\u5206\u5757\u5206\u9694\u7b26\n llm_description: The separator used to split subchunks\n max: null\n min: null\n name: subchunk_separator\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: string\n - auto_generate: null\n default: 512\n form: llm\n human_description:\n en_US: Maximum length for subchunking\n ja_JP: Maximum length for subchunking\n pt_BR: Comprimento m\u00e1ximo para subdivis\u00e3o\n zh_Hans: \u7528\u4e8e\u5b50\u5206\u5757\u7684\u6700\u5927\u957f\u5ea6\n label:\n en_US: Maximum Child Chunk Length\n ja_JP: Maximum Child Chunk Length\n pt_BR: Comprimento M\u00e1ximo de Subdivis\u00e3o\n zh_Hans: \u5b50\u5206\u5757\u6700\u5927\u957f\u5ea6\n llm_description: Maximum length allowed per subchunk\n max: null\n min: null\n name: subchunk_max_length\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: number\n - auto_generate: null\n default: 0\n form: llm\n human_description:\n en_US: Whether to remove consecutive spaces, newlines and tabs\n ja_JP: Whether to remove consecutive spaces, newlines and tabs\n pt_BR: Se deve remover espa\u00e7os extras no texto\n zh_Hans: 
\u662f\u5426\u79fb\u9664\u6587\u672c\u4e2d\u7684\u8fde\u7eed\u7a7a\u683c\u3001\u6362\u884c\u7b26\u548c\u5236\u8868\u7b26\n label:\n en_US: Replace consecutive spaces, newlines and tabs\n ja_JP: Replace consecutive spaces, newlines and tabs\n pt_BR: Substituir espa\u00e7os consecutivos, novas linhas e guias\n zh_Hans: \u66ff\u6362\u8fde\u7eed\u7a7a\u683c\u3001\u6362\u884c\u7b26\u548c\u5236\u8868\u7b26\n llm_description: Whether to remove consecutive spaces, newlines and tabs\n max: null\n min: null\n name: remove_extra_spaces\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n - auto_generate: null\n default: 0\n form: llm\n human_description:\n en_US: Whether to remove URLs and emails in the text\n ja_JP: Whether to remove URLs and emails in the text\n pt_BR: Se deve remover URLs e e-mails no texto\n zh_Hans: \u662f\u5426\u79fb\u9664\u6587\u672c\u4e2d\u7684URL\u548c\u7535\u5b50\u90ae\u4ef6\u5730\u5740\n label:\n en_US: Delete all URLs and email addresses\n ja_JP: Delete all URLs and email addresses\n pt_BR: Remover todas as URLs e e-mails\n zh_Hans: \u5220\u9664\u6240\u6709URL\u548c\u7535\u5b50\u90ae\u4ef6\u5730\u5740\n llm_description: Whether to remove URLs and emails in the text\n max: null\n min: null\n name: remove_urls_emails\n options: []\n placeholder: null\n precision: null\n required: false\n scope: null\n template: null\n type: boolean\n params:\n input_text: ''\n max_length: ''\n parent_mode: ''\n remove_extra_spaces: ''\n remove_urls_emails: ''\n separator: ''\n subchunk_max_length: ''\n subchunk_separator: ''\n provider_id: langgenius\/parentchild_chunker\/parentchild_chunker\n provider_name: langgenius\/parentchild_chunker\/parentchild_chunker\n provider_type: builtin\n selected: false\n title: Parent-child Chunker\n tool_configurations: {}\n tool_description: Process documents into parent-child chunk structures\n tool_label: Parent-child Chunker\n tool_name: parentchild_chunker\n 
tool_node_version: '2'\n tool_parameters:\n input_text:\n type: mixed\n value: '{{#1751281136356.text#}}'\n max_length:\n type: variable\n value:\n - rag\n - shared\n - Maximum_Parent_Length\n parent_mode:\n type: variable\n value:\n - rag\n - shared\n - Parent_Mode\n remove_extra_spaces:\n type: variable\n value:\n - rag\n - shared\n - clean_1\n remove_urls_emails:\n type: variable\n value:\n - rag\n - shared\n - clean_2\n separator:\n type: mixed\n value: '{{#rag.shared.Parent_Delimiter#}}'\n subchunk_max_length:\n type: variable\n value:\n - rag\n - shared\n - Maximum_Child_Length\n subchunk_separator:\n type: mixed\n value: '{{#rag.shared.Child_Delimiter#}}'\n type: tool\n height: 52\n id: '1751338398711'\n position:\n x: 42.95253988413964\n y: 282\n positionAbsolute:\n x: 42.95253988413964\n y: 282\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 242\n viewport:\n x: 628.3302331655243\n y: 120.08894361588159\n zoom: 0.7027501395646496\n rag_pipeline_variables:\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: paragraph\n label: Parent Mode\n max_length: 48\n options:\n - paragraph\n - full_doc\n placeholder: null\n required: true\n tooltips: 'Parent Mode provides two options: paragraph mode splits text into paragraphs\n as parent chunks for retrieval, while full_doc mode uses the entire document\n as a single parent chunk (text beyond 10,000 tokens will be truncated).'\n type: select\n unit: null\n variable: Parent_Mode\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: \\n\\n\n label: Parent Delimiter\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: A delimiter is the character used to separate text. \\n\\n is recommended\n for splitting the original document into large parent chunks. 
You can also use\n special delimiters defined by yourself.\n type: text-input\n unit: null\n variable: Parent_Delimiter\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: 1024\n label: Maximum Parent Length\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: null\n type: number\n unit: tokens\n variable: Maximum_Parent_Length\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: \\n\n label: Child Delimiter\n max_length: 48\n options: []\n placeholder: null\n required: true\n tooltips: A delimiter is the character used to separate text. \\n is recommended\n for splitting parent chunks into small child chunks. You can also use special\n delimiters defined by yourself.\n type: text-input\n unit: null\n variable: Child_Delimiter\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: 256\n label: Maximum Child Length\n max_length: 48\n options: []\n placeholder: null\n required: true\n tooltips: null\n type: number\n unit: tokens\n variable: Maximum_Child_Length\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: true\n label: Replace consecutive spaces, newlines and tabs.\n max_length: 48\n options: []\n placeholder: null\n required: true\n tooltips: null\n type: checkbox\n unit: null\n variable: clean_1\n - allow_file_extension: null\n allow_file_upload_methods: null\n allowed_file_types: null\n belong_to_node_id: shared\n default_value: null\n label: Delete all URLs and email addresses.\n max_length: 48\n options: []\n placeholder: null\n required: false\n tooltips: null\n type: checkbox\n unit: null\n variable: clean_2\n", @@ -7340,4 +7372,4 @@ "name": "Complex PDF with Images & Tables" } } -} \ No newline 
at end of file +} diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 619b80ff28..f37598fb31 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,11 +15,11 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.variables.segment_group import SegmentGroup -from core.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.file import helpers as file_helpers +from core.workflow.variables.segment_group import SegmentGroup +from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment +from core.workflow.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 2911b1cf18..7e285c8da9 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from 
core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index fd928b077d..014f4c4132 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -36,9 +36,9 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data" ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" -def account_initialization_required(view: Callable[P, R]): +def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check account initialization current_user, _ = current_account_with_tenant() if current_user.status == AccountStatus.UNINITIALIZED: @@ -214,9 +214,9 @@ def cloud_utm_record(view: Callable[P, R]): return decorated -def setup_required(view: Callable[P, R]): +def setup_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check setup if ( dify_config.EDITION == "SELF_HOSTED" diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 8b20442eab..18ae75a087 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,7 +25,6 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration -from core.variables.variables import 
Variable from core.workflow.enums import WorkflowType from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -34,6 +33,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader +from core.workflow.variables.variables import Variable from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 00a6a3d9af..534ef6994a 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -669,16 +669,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) -> Generator[StreamResponse, None, None]: """Handle retriever resources events.""" self._message_cycle_manager.handle_retriever_resources(event) - return - yield # Make this a generator + yield from () def _handle_annotation_reply_event( self, event: QueueAnnotationReplyEvent, **kwargs ) -> Generator[StreamResponse, None, None]: """Handle annotation reply events.""" self._message_cycle_manager.handle_annotation_reply(event) - return - yield # Make this a generator + yield from () def _handle_message_replace_event( self, event: QueueMessageReplaceEvent, **kwargs diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index d2f09a25c3..af1f1d7c66 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -122,7 +122,7 @@ class AppQueueManager(ABC): """Attach the live graph runtime state reference for downstream 
consumers.""" self._graph_runtime_state = graph_runtime_state - def publish(self, event: AppQueueEvent, pub_from: PublishFrom): + def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue :param event: diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 510abdc1d0..d4e801de13 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -49,7 +49,6 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.trigger_manager import TriggerManager -from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.workflow.entities.pause_reason import HumanInputRequired from core.workflow.entities.workflow_start_reason import WorkflowStartReason from core.workflow.enums import ( @@ -62,6 +61,7 @@ from core.workflow.enums import ( from core.workflow.file import FILE_MODEL_IDENTITY, File from core.workflow.runtime import GraphRuntimeState from core.workflow.system_variable import SystemVariable +from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 8ea34344b2..02caf8f511 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -11,7 +11,6 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.app.workflow.node_factory import DifyNodeFactory -from 
core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.enums import WorkflowType from core.workflow.graph import Graph @@ -21,6 +20,7 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader +from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.dataset import Document, Pipeline diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index c070845b73..a748d90387 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,12 @@ import logging -from core.variables import VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.enums import NodeType from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.variables import VariableBase logger = logging.getLogger(__name__) diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 2b162920ee..ebae830389 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -83,14 +83,21 @@ def fetch_model_config( raise ModelNotExistError(f"Model {node_data_model.name} not exist.") provider_model.raise_for_status() - stop: list[str] = [] - if "stop" 
in node_data_model.completion_params: - stop = node_data_model.completion_params.pop("stop") + completion_params = dict(node_data_model.completion_params) + stop = completion_params.pop("stop", []) + if not isinstance(stop, list): + stop = [] model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if not model_schema: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + model_instance.provider = node_data_model.provider + model_instance.model_name = node_data_model.name + model_instance.credentials = credentials + model_instance.parameters = completion_params + model_instance.stop = tuple(stop) + return model_instance, ModelConfigWithCredentialsEntity( provider=node_data_model.provider, model=node_data_model.name, @@ -98,6 +105,6 @@ def fetch_model_config( mode=node_data_model.mode, provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=node_data_model.completion_params, + parameters=completion_params, stop=stop, ) diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index 07dec1b070..41b8c9fd7b 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -1,13 +1,20 @@ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, Any, cast, final from typing_extensions import override from configs import dify_config from core.app.llm.model_access import build_dify_model_access -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor -from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.datasource.datasource_manager import DatasourceManager +from core.helper.code_executor.code_executor import ( + CodeExecutionError, + CodeExecutor, +) from core.helper.ssrf_proxy import ssrf_proxy +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import 
ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.graph_config import NodeConfigDict @@ -18,10 +25,15 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor from core.workflow.nodes.code.entities import CodeLanguage from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.datasource import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm import llm_utils +from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.llm.protocols import PromptMessageMemory from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode @@ -70,7 +82,6 @@ class DifyNodeFactory(NodeFactory): self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() - self._code_providers: tuple[type[CodeNodeProvider], ...] 
= CodeNode.default_code_providers() self._code_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, @@ -142,7 +153,6 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, code_executor=self._code_executor, - code_providers=self._code_providers, code_limits=self._code_limits, ) @@ -169,6 +179,8 @@ class DifyNodeFactory(NodeFactory): ) if node_type == NodeType.LLM: + model_instance = self._build_model_instance_for_llm_node(node_data) + memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) return LLMNode( id=node_id, config=node_config, @@ -176,6 +188,17 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, + memory=memory, + ) + + if node_type == NodeType.DATASOURCE: + return DatasourceNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + datasource_manager=DatasourceManager, ) if node_type == NodeType.KNOWLEDGE_RETRIEVAL: @@ -197,6 +220,7 @@ class DifyNodeFactory(NodeFactory): ) if node_type == NodeType.QUESTION_CLASSIFIER: + model_instance = self._build_model_instance_for_llm_node(node_data) return QuestionClassifierNode( id=node_id, config=node_config, @@ -204,9 +228,11 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, ) if node_type == NodeType.PARAMETER_EXTRACTOR: + model_instance = self._build_model_instance_for_llm_node(node_data) return ParameterExtractorNode( id=node_id, config=node_config, @@ -214,6 +240,7 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, 
credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, ) return node_class( @@ -222,3 +249,55 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) + + def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance: + node_data_model = ModelConfig.model_validate(node_data["model"]) + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") + + credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) + model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) + provider_model_bundle = model_instance.provider_model_bundle + + provider_model = provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, + model_type=ModelType.LLM, + ) + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() + + completion_params = dict(node_data_model.completion_params) + stop = completion_params.pop("stop", []) + if not isinstance(stop, list): + stop = [] + + model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) + if not model_schema: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + model_instance.provider = node_data_model.provider + model_instance.model_name = node_data_model.name + model_instance.credentials = credentials + model_instance.parameters = completion_params + model_instance.stop = tuple(stop) + model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) + return model_instance + + def _build_memory_for_llm_node( + self, + *, + node_data: Mapping[str, Any], + model_instance: ModelInstance, + ) -> PromptMessageMemory | None: + raw_memory_config = node_data.get("memory") + 
if raw_memory_config is None: + return None + + node_memory = MemoryConfig.model_validate(raw_memory_config) + return llm_utils.fetch_memory( + variable_pool=self.graph_runtime_state.variable_pool, + app_id=self.graph_init_params.app_id, + node_data_memory=node_memory, + model_instance=model_instance, + ) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 002415a7db..9c48f755a9 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -1,16 +1,39 @@ import logging +from collections.abc import Generator from threading import Lock +from typing import Any, cast + +from sqlalchemy import select import contexts from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController -from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.datasource.entities.datasource_entities import ( + DatasourceMessage, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + OnlineDriveDownloadFileRequest, +) from core.datasource.errors import DatasourceProviderNotFoundError from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController +from core.db.session_factory import session_factory from 
core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.enums import WorkflowNodeExecutionMetadataKey +from core.workflow.file import File +from core.workflow.file.enums import FileTransferMethod, FileType +from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam +from factories import file_factory +from models.model import UploadFile +from models.tools import ToolFile +from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -103,3 +126,238 @@ class DatasourceManager: tenant_id, datasource_type, ).get_datasource(datasource_name) + + @classmethod + def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: + datasource_runtime = cls.get_datasource_runtime( + provider_id=provider_id, + datasource_name=datasource_name, + tenant_id=tenant_id, + datasource_type=DatasourceProviderType.value_of(datasource_type), + ) + return datasource_runtime.get_icon_url(tenant_id) + + @classmethod + def stream_online_results( + cls, + *, + user_id: str, + datasource_name: str, + datasource_type: str, + provider_id: str, + tenant_id: str, + provider: str, + plugin_id: str, + credential_id: str, + datasource_param: DatasourceParameter | None = None, + online_drive_request: OnlineDriveDownloadFileParam | None = None, + ) -> Generator[DatasourceMessage, None, Any]: + """ + Pull-based streaming of domain messages from datasource plugins. + Returns a generator that yields DatasourceMessage and finally returns a minimal final payload. + Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly. 
+ """ + ds_type = DatasourceProviderType.value_of(datasource_type) + runtime = cls.get_datasource_runtime( + provider_id=provider_id, + datasource_name=datasource_name, + tenant_id=tenant_id, + datasource_type=ds_type, + ) + + dsp_service = DatasourceProviderService() + credentials = dsp_service.get_datasource_credentials( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + credential_id=credential_id, + ) + + if ds_type == DatasourceProviderType.ONLINE_DOCUMENT: + doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime) + if credentials: + doc_runtime.runtime.credentials = credentials + if datasource_param is None: + raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming") + inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content( + user_id=user_id, + datasource_parameters=GetOnlineDocumentPageContentRequest( + workspace_id=datasource_param.workspace_id, + page_id=datasource_param.page_id, + type=datasource_param.type, + ), + provider_type=ds_type, + ) + elif ds_type == DatasourceProviderType.ONLINE_DRIVE: + drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime) + if credentials: + drive_runtime.runtime.credentials = credentials + if online_drive_request is None: + raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming") + inner_gen = drive_runtime.online_drive_download_file( + user_id=user_id, + request=OnlineDriveDownloadFileRequest( + id=online_drive_request.id, + bucket=online_drive_request.bucket, + ), + provider_type=ds_type, + ) + else: + raise ValueError(f"Unsupported datasource type for streaming: {ds_type}") + + # Bridge through to caller while preserving generator return contract + yield from inner_gen + # No structured final data here; node/adapter will assemble outputs + return {} + + @classmethod + def stream_node_events( + cls, + *, + node_id: str, + user_id: str, + datasource_name: str, + datasource_type: str, + provider_id: str, + 
tenant_id: str, + provider: str, + plugin_id: str, + credential_id: str, + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + variable_pool: Any, + datasource_param: DatasourceParameter | None = None, + online_drive_request: OnlineDriveDownloadFileParam | None = None, + ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: + ds_type = DatasourceProviderType.value_of(datasource_type) + + messages = cls.stream_online_results( + user_id=user_id, + datasource_name=datasource_name, + datasource_type=datasource_type, + provider_id=provider_id, + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + credential_id=credential_id, + datasource_param=datasource_param, + online_drive_request=online_drive_request, + ) + + transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None + ) + + variables: dict[str, Any] = {} + file_out: File | None = None + + for message in transformed: + mtype = message.type + if mtype in { + DatasourceMessage.MessageType.IMAGE_LINK, + DatasourceMessage.MessageType.BINARY_LINK, + DatasourceMessage.MessageType.IMAGE, + }: + wanted_ds_type = ds_type in { + DatasourceProviderType.ONLINE_DRIVE, + DatasourceProviderType.ONLINE_DOCUMENT, + } + if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage): + url = message.message.text + + datasource_file_id = str(url).split("/")[-1].split(".")[0] + with session_factory.create_session() as session: + stmt = select(ToolFile).where( + ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id + ) + datasource_file = session.scalar(stmt) + if not datasource_file: + raise ValueError( + f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}" + ) + mime_type = datasource_file.mimetype + if datasource_file is not None: + mapping = { + "tool_file_id": datasource_file_id, + "type": 
file_factory.get_file_type_by_mime_type(mime_type), + "transfer_method": FileTransferMethod.TOOL_FILE, + "url": url, + } + file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + elif mtype == DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) + elif mtype == DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + yield StreamChunkEvent( + selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False + ) + elif mtype == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + name = message.message.variable_name + value = message.message.variable_value + if message.message.stream: + assert isinstance(value, str), "stream variable_value must be str" + variables[name] = variables.get(name, "") + value + yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False) + else: + variables[name] = value + elif mtype == DatasourceMessage.MessageType.FILE: + if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta: + f = message.meta.get("file") + if isinstance(f, File): + file_out = f + else: + pass + + yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True) + + if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None: + variable_pool.add([node_id, "file"], file_out) + + if ds_type == DatasourceProviderType.ONLINE_DOCUMENT: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={**variables}, + ) + ) + else: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + 
inputs=parameters_for_log, + metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, + outputs={ + "file": file_out, + "datasource_type": ds_type, + }, + ) + ) + + @classmethod + def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: + with session_factory.create_session() as session: + upload_file = ( + session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first() + ) + if not upload_file: + raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") + + file_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=tenant_id, + type=FileType.CUSTOM, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=upload_file.source_url, + ) + return file_info diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index dde7d59726..a063a3680b 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -379,4 +379,11 @@ class OnlineDriveDownloadFileRequest(BaseModel): """ id: str = Field(..., description="The id of the file") - bucket: str | None = Field(None, description="The name of the bucket") + bucket: str = Field("", description="The name of the bucket") + + @field_validator("bucket", mode="before") + @classmethod + def _coerce_bucket(cls, v) -> str: + if v is None: + return "" + return str(v) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 73174ed28d..d581b3ac39 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,6 +1,5 @@ import logging from collections.abc import Mapping -from enum import 
StrEnum from threading import Lock from typing import Any @@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from core.workflow.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) @@ -40,12 +40,6 @@ class CodeExecutionResponse(BaseModel): data: Data -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - def _build_code_executor_client() -> httpx.Client: return httpx.Client( verify=CODE_EXECUTION_SSL_VERIFY, diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 5cdea19a8d..1b56eaba21 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.variables.utils import dumps_with_segments +from core.workflow.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/model_manager.py b/api/core/model_manager.py index ac096c5e54..2b3a3be1b9 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable, Generator, Iterable, Sequence +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload from configs import dify_config @@ -38,6 +38,9 @@ class ModelInstance: self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = 
self._fetch_credentials_from_bundle(provider_model_bundle, model) + # Runtime LLM invocation fields. + self.parameters: Mapping[str, Any] = {} + self.stop: Sequence[str] = () self.model_type_instance = self.provider_model_bundle.model_type_instance self.load_balancing_manager = self._get_load_balancing_manager( configuration=provider_model_bundle.configuration, diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index bbbdec61d1..c32ab0879e 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -83,19 +83,21 @@ def _merge_tool_call_delta( tool_call.function.arguments += delta.function.arguments -def _build_llm_result_from_first_chunk( +def _build_llm_result_from_chunks( model: str, prompt_messages: Sequence[PromptMessage], chunks: Iterator[LLMResultChunk], ) -> LLMResult: """ - Build a single `LLMResult` from the first returned chunk. + Build a single `LLMResult` by accumulating all returned chunks. - This is used for `stream=False` because the plugin side may still implement the response via a chunked stream. + Some models only support streaming output (e.g. Qwen3 open-source edition) + and the plugin side may still implement the response via a chunked stream, + so all chunks must be consumed and concatenated into a single ``LLMResult``. - Note: - This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying - streaming resources are released (e.g., HTTP connections owned by the plugin runtime). + The ``usage`` is taken from the last chunk that carries it, which is the + typical convention for streaming responses (the final chunk contains the + aggregated token counts). 
""" content = "" content_list: list[PromptMessageContentUnionTypes] = [] @@ -104,24 +106,27 @@ def _build_llm_result_from_first_chunk( tools_calls: list[AssistantPromptMessage.ToolCall] = [] try: - first_chunk = next(chunks, None) - if first_chunk is not None: - if isinstance(first_chunk.delta.message.content, str): - content += first_chunk.delta.message.content - elif isinstance(first_chunk.delta.message.content, list): - content_list.extend(first_chunk.delta.message.content) + for chunk in chunks: + if isinstance(chunk.delta.message.content, str): + content += chunk.delta.message.content + elif isinstance(chunk.delta.message.content, list): + content_list.extend(chunk.delta.message.content) - if first_chunk.delta.message.tool_calls: - _increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls) + if chunk.delta.message.tool_calls: + _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - usage = first_chunk.delta.usage or LLMUsage.empty_usage() - system_fingerprint = first_chunk.system_fingerprint + if chunk.delta.usage: + usage = chunk.delta.usage + if chunk.system_fingerprint: + system_fingerprint = chunk.system_fingerprint + except Exception: + logger.exception("Error while consuming non-stream plugin chunk iterator.") + raise finally: - try: - for _ in chunks: - pass - except Exception: - logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True) + # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). 
+ close = getattr(chunks, "close", None) + if callable(close): + close() return LLMResult( model=model, @@ -174,7 +179,7 @@ def _normalize_non_stream_plugin_result( ) -> LLMResult: if isinstance(result, LLMResult): return result - return _build_llm_result_from_first_chunk(model=model, prompt_messages=prompt_messages, chunks=result) + return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) def _increase_tool_call( diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 04b46d67a8..8c081ae225 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -14,10 +14,9 @@ class BaseTraceInstance(ABC): Base trace instance for ops trace services """ - @abstractmethod def __init__(self, trace_config: BaseTracingConfig): """ - Abstract initializer for the trace instance. + Initializer for the trace instance. Distribute trace tasks by matching entities """ self.trace_config = trace_config diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 549e428f88..177991e645 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -41,8 +41,8 @@ logger = logging.getLogger(__name__) class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, provider: str) -> dict[str, Any]: - match provider: + def __getitem__(self, key: str) -> dict[str, Any]: + match key: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +149,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + raise KeyError(f"Unsupported tracing provider: {key}") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/prompt/advanced_prompt_transform.py 
b/api/core/prompt/advanced_prompt_transform.py index fd1b7d838c..771b6be332 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -4,6 +4,7 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, @@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + model_instance=model_instance, image_detail_config=image_detail_config, ) elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): @@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + model_instance=model_instance, image_detail_config=image_detail_config, ) @@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ @@ -111,6 +116,7 @@ class 
AdvancedPromptTransform(PromptTransform): parser=parser, prompt_inputs=prompt_inputs, model_config=model_config, + model_instance=model_instance, ) if query: @@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ @@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform): prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if memory and memory_config: - prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) - + prompt_messages = self._append_chat_histories( + memory, + memory_config, + prompt_messages, + model_config=model_config, + model_instance=model_instance, + ) if files and query is not None: for file in files: prompt_message_contents.append( @@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform): role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str], - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> Mapping[str, str]: prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: @@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform): prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + rest_tokens = self._calculate_rest_token( + [tmp_human_message], + model_config=model_config, + model_instance=model_instance, + ) histories = 
self._get_history_messages_from_memory( memory=memory, diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 2b32062140..c1ae47709f 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform): if not self.memory: return prompt_messages - max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config) + max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config) model_type_instance = self.model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index a6e873d587..22ef5809bb 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: + def _resolve_model_runtime( + self, + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, + ) -> tuple[ModelInstance, AIModelEntity]: + if model_instance is None: + if model_config is None: + raise ValueError("Either model_config or model_instance must be provided.") + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + model_instance.credentials = 
model_config.credentials + model_instance.parameters = model_config.parameters + model_instance.stop = model_config.stop + + model_schema = model_instance.model_type_instance.get_model_schema( + model=model_instance.model_name, + credentials=model_instance.credentials, + ) + if model_schema is None: + if model_config is None: + raise ValueError("Model schema not found for the provided model instance.") + model_schema = model_config.model_schema + + return model_instance, model_schema + def _append_chat_histories( self, memory: TokenBufferMemory, memory_config: MemoryConfig, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity, + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> list[PromptMessage]: - rest_tokens = self._calculate_rest_token(prompt_messages, model_config) + rest_tokens = self._calculate_rest_token( + prompt_messages, + model_config=model_config, + model_instance=model_instance, + ) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages def _calculate_rest_token( - self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + self, + prompt_messages: list[PromptMessage], + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> int: + model_instance, model_schema = self._resolve_model_runtime( + model_config=model_config, + model_instance=model_instance, + ) + model_parameters = model_instance.parameters rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - 
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") + model_parameters.get(parameter_rule.name) + or model_parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d6abbaaa69..936a093488 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform): if memory: tmp_human_message = UserPromptMessage(content=prompt) - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config) histories = self._get_history_messages_from_memory( memory=memory, memory_config=MemoryConfig( diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 6e76321ea0..e8b3fa1508 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -75,15 +75,15 @@ class BaseIndexProcessor(ABC): multimodal_documents: list[AttachmentDocument] | None = None, with_keywords: bool = True, **kwargs, - ): + ) -> None: raise NotImplementedError @abstractmethod - def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: raise 
NotImplementedError @abstractmethod - def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: raise NotImplementedError @abstractmethod diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 3b42560fd6..cfeee4afc7 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -115,7 +115,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): multimodal_documents: list[AttachmentDocument] | None = None, with_keywords: bool = True, **kwargs, - ): + ) -> None: if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -130,7 +130,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): else: keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # Note: Summary indexes are now disabled (not deleted) when segments are disabled. # This method is called for actual deletion scenarios (e.g., when segment is deleted). # For disable operations, disable_summaries_for_segments is called directly in the task. 
@@ -196,7 +196,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: documents: list[Any] = [] all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): @@ -469,7 +469,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if not isinstance(result, LLMResult): raise ValueError("Expected LLMResult when stream=False") - summary_content = getattr(result.message, "content", "") + summary_content = result.message.get_text_content() usage = result.usage # Deduct quota for summary generation (same as workflow nodes) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 0ea77405ed..367f0aec00 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -126,7 +126,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): multimodal_documents: list[AttachmentDocument] | None = None, with_keywords: bool = True, **kwargs, - ): + ) -> None: if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if multimodal_documents and dataset.is_multimodal: vector.create_multimodal(multimodal_documents) - def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # node_ids is segment's node_ids # Note: Summary indexes are now disabled (not deleted) when segments are disabled. # This method is called for actual deletion scenarios (e.g., when segment is deleted). 
@@ -272,7 +272,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_nodes.append(child_document) return child_nodes - def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: parent_childs = ParentChildStructureChunk.model_validate(chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 40d9caaa69..503cce2132 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -139,14 +139,14 @@ class QAIndexProcessor(BaseIndexProcessor): multimodal_documents: list[AttachmentDocument] | None = None, with_keywords: bool = True, **kwargs, - ): + ) -> None: if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if multimodal_documents and dataset.is_multimodal: vector.create_multimodal(multimodal_documents) - def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): + def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs) -> None: # Note: Summary indexes are now disabled (not deleted) when segments are disabled. # This method is called for actual deletion scenarios (e.g., when segment is deleted). # For disable operations, disable_summaries_for_segments is called directly in the task. 
@@ -206,7 +206,7 @@ class QAIndexProcessor(BaseIndexProcessor): docs.append(doc) return docs - def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any) -> None: qa_chunks = QAStructureChunk.model_validate(chunks) documents = [] for qa_chunk in qa_chunks.qa_chunks: diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py index 75f47691da..6bfb2b2880 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/core/workflow/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.variables import VariableBase +from core.workflow.variables import VariableBase class ConversationVariableUpdater(Protocol): diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 41276eb444..7e7b65247b 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.variables.variables import Variable +from core.workflow.variables.variables import Variable class CommandType(StrEnum): diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5c39a67102..ac86b1784f 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayFileSegment, StringSegment from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod from core.workflow.nodes.base.node 
import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from core.workflow.variables.segments import ArrayFileSegment, StringSegment from extensions.ext_database import db from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index d3b3fac107..388447368e 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,13 +1,13 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.variables import ArrayFileSegment, FileSegment, Segment from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.answer.entities import AnswerNodeData from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index f7a6c41f0a..7b1cbfcfea 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,17 +1,15 @@ from collections.abc import Mapping, Sequence from decimal import Decimal -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast +from textwrap import dedent +from typing import TYPE_CHECKING, Any, Protocol, cast -from core.helper.code_executor.code_node_provider import CodeNodeProvider -from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider 
-from core.variables.segments import ArrayFileSegment -from core.variables.types import SegmentType from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.variables.segments import ArrayFileSegment +from core.workflow.variables.types import SegmentType from .exc import ( CodeNodeError, @@ -36,12 +34,44 @@ class WorkflowCodeExecutor(Protocol): def is_execution_error(self, error: Exception) -> bool: ... +def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: + return { + "type": "code", + "config": { + "variables": [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ], + "code_language": language, + "code": code, + "outputs": {"result": {"type": "string", "children": None}}, + }, + } + + +_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { + CodeLanguage.PYTHON3: dedent( + """ + def main(arg1: str, arg2: str): + return { + "result": arg1 + arg2, + } + """ + ), + CodeLanguage.JAVASCRIPT: dedent( + """ + function main({arg1, arg2}) { + return { + result: arg1 + arg2 + } + } + """ + ), +} + + class CodeNode(Node[CodeNodeData]): node_type = NodeType.CODE - _DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = ( - Python3CodeProvider, - JavascriptCodeProvider, - ) _limits: CodeNodeLimits def __init__( @@ -52,7 +82,6 @@ class CodeNode(Node[CodeNodeData]): graph_runtime_state: "GraphRuntimeState", *, code_executor: WorkflowCodeExecutor, - code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits, ) -> None: super().__init__( @@ -62,9 +91,6 @@ class CodeNode(Node[CodeNodeData]): graph_runtime_state=graph_runtime_state, ) self._code_executor: WorkflowCodeExecutor = code_executor - 
self._code_providers: tuple[type[CodeNodeProvider], ...] = ( - tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS - ) self._limits = code_limits @classmethod @@ -78,15 +104,10 @@ class CodeNode(Node[CodeNodeData]): if filters: code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - code_provider: type[CodeNodeProvider] = next( - provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language) - ) - - return code_provider.get_default_config() - - @classmethod - def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]: - return cls._DEFAULT_CODE_PROVIDERS + default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) + if default_code is None: + raise CodeNodeError(f"Unsupported code language: {code_language}") + return _build_default_config(language=code_language, code=default_code) @classmethod def version(cls) -> str: @@ -108,7 +129,6 @@ class CodeNode(Node[CodeNodeData]): variables[variable_name] = variable.to_object() if variable else None # Run code try: - _ = self._select_code_provider(code_language) result = self._code_executor.execute( language=code_language, code=code, @@ -130,12 +150,6 @@ class CodeNode(Node[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]: - for provider in self._code_providers: - if provider.is_accept_language(code_language): - return provider - raise CodeNodeError(f"Unsupported code language: {code_language}") - def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 8026011196..8b73b89e2f 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,11 +1,18 @@ +from enum import StrEnum from typing 
import Annotated, Literal from pydantic import AfterValidator, BaseModel -from core.helper.code_executor.code_executor import CodeLanguage -from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.entities import VariableSelector +from core.workflow.variables.types import SegmentType + + +class CodeLanguage(StrEnum): + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" + _ALLOWED_OUTPUT_FROM_CODE = frozenset( [ diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 80869ac7f7..17f8bcb2db 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,40 +1,26 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.datasource.entities.datasource_entities import ( - DatasourceMessage, - DatasourceParameter, - DatasourceProviderType, - GetOnlineDocumentPageContentRequest, - OnlineDriveDownloadFileRequest, -) -from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin -from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin -from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from core.variables.segments import ArrayAnySegment -from core.variables.variables import ArrayAnyVariable from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey -from core.workflow.file import File -from core.workflow.file.enums import FileTransferMethod, FileType 
-from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.tool.exc import ToolFileError -from core.workflow.runtime import VariablePool -from extensions.ext_database import db -from factories import file_factory -from models.model import UploadFile -from models.tools import ToolFile -from services.datasource_provider_service import DatasourceProviderService +from core.workflow.repositories.datasource_manager_protocol import ( + DatasourceManagerProtocol, + DatasourceParameter, + OnlineDriveDownloadFileParam, +) from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from .entities import DatasourceNodeData -from .exc import DatasourceNodeError, DatasourceParameterError +from .exc import DatasourceNodeError + +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -45,6 +31,22 @@ class DatasourceNode(Node[DatasourceNodeData]): node_type = NodeType.DATASOURCE execution_type = NodeExecutionType.ROOT + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + datasource_manager: DatasourceManagerProtocol, + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self.datasource_manager = datasource_manager + def _run(self) -> Generator: """ Run the datasource node @@ -52,84 +54,69 @@ class DatasourceNode(Node[DatasourceNodeData]): node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segement = variable_pool.get(["sys", 
SystemVariableKey.DATASOURCE_TYPE]) - if not datasource_type_segement: + datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") - datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None - datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) - if not datasource_info_segement: + datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None + datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") - datasource_info_value = datasource_info_segement.value + datasource_info_value = datasource_info_segment.value if not isinstance(datasource_info_value, dict): raise DatasourceNodeError("Invalid datasource info format") datasource_info: dict[str, Any] = datasource_info_value - # get datasource runtime - from core.datasource.datasource_manager import DatasourceManager if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") datasource_type = DatasourceProviderType.value_of(datasource_type) + provider_id = f"{node_data.plugin_id}/{node_data.provider_name}" - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", + datasource_info["icon"] = self.datasource_manager.get_icon_url( + provider_id=provider_id, datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=datasource_type, + datasource_type=datasource_type.value, ) - datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) parameters_for_log = datasource_info try: - datasource_provider_service = DatasourceProviderService() - credentials = datasource_provider_service.get_datasource_credentials( - tenant_id=self.tenant_id, 
- provider=node_data.provider_name, - plugin_id=node_data.plugin_id, - credential_id=datasource_info.get("credential_id", ""), - ) match datasource_type: - case DatasourceProviderType.ONLINE_DOCUMENT: - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - if credentials: - datasource_runtime.runtime.credentials = credentials - online_document_result: Generator[DatasourceMessage, None, None] = ( - datasource_runtime.get_online_document_page_content( - user_id=self.user_id, - datasource_parameters=GetOnlineDocumentPageContentRequest( - workspace_id=datasource_info.get("workspace_id", ""), - page_id=datasource_info.get("page", {}).get("page_id", ""), - type=datasource_info.get("page", {}).get("type", ""), - ), - provider_type=datasource_type, + case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE: + # Build typed request objects + datasource_parameters = None + if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT: + datasource_parameters = DatasourceParameter( + workspace_id=datasource_info.get("workspace_id", ""), + page_id=datasource_info.get("page", {}).get("page_id", ""), + type=datasource_info.get("page", {}).get("type", ""), ) - ) - yield from self._transform_message( - messages=online_document_result, - parameters_for_log=parameters_for_log, - datasource_info=datasource_info, - ) - case DatasourceProviderType.ONLINE_DRIVE: - datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime) - if credentials: - datasource_runtime.runtime.credentials = credentials - online_drive_result: Generator[DatasourceMessage, None, None] = ( - datasource_runtime.online_drive_download_file( - user_id=self.user_id, - request=OnlineDriveDownloadFileRequest( - id=datasource_info.get("id", ""), - bucket=datasource_info.get("bucket"), - ), - provider_type=datasource_type, + + online_drive_request = None + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: + online_drive_request = 
OnlineDriveDownloadFileParam( + id=datasource_info.get("id", ""), + bucket=datasource_info.get("bucket", ""), ) - ) - yield from self._transform_datasource_file_message( - messages=online_drive_result, + + credential_id = datasource_info.get("credential_id", "") + + yield from self.datasource_manager.stream_node_events( + node_id=self._node_id, + user_id=self.user_id, + datasource_name=node_data.datasource_name or "", + datasource_type=datasource_type.value, + provider_id=provider_id, + tenant_id=self.tenant_id, + provider=node_data.provider_name, + plugin_id=node_data.plugin_id, + credential_id=credential_id, parameters_for_log=parameters_for_log, datasource_info=datasource_info, variable_pool=variable_pool, - datasource_type=datasource_type, + datasource_param=datasource_parameters, + online_drive_request=online_drive_request, ) case DatasourceProviderType.WEBSITE_CRAWL: yield StreamCompletedEvent( @@ -147,23 +134,9 @@ class DatasourceNode(Node[DatasourceNodeData]): related_id = datasource_info.get("related_id") if not related_id: raise DatasourceNodeError("File is not exist") - upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first() - if not upload_file: - raise ValueError("Invalid upload file Info") - file_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." 
+ upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.tenant_id, - type=FileType.CUSTOM, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=upload_file.source_url, + file_info = self.datasource_manager.get_upload_file_by_id( + file_id=related_id, tenant_id=self.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) @@ -201,55 +174,6 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) - def _generate_parameters( - self, - *, - datasource_parameters: Sequence[DatasourceParameter], - variable_pool: VariablePool, - node_data: DatasourceNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. 
- - """ - datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} - - result: dict[str, Any] = {} - if node_data.datasource_parameters: - for parameter_name in node_data.datasource_parameters: - parameter = datasource_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - datasource_input = node_data.datasource_parameters[parameter_name] - if datasource_input.type == "variable": - variable = variable_pool.get(datasource_input.value) - if variable is None: - raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") - parameter_value = variable.value - elif datasource_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(datasource_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") - result[parameter_name] = parameter_value - - return result - - def _fetch_files(self, variable_pool: VariablePool) -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -287,206 +211,6 @@ class DatasourceNode(Node[DatasourceNodeData]): return result - def _transform_message( - self, - messages: Generator[DatasourceMessage, None, None], - parameters_for_log: dict[str, Any], - datasource_info: dict[str, Any], - ) -> Generator: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( - messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict 
| list] = [] - - variables: dict[str, Any] = {} - - for message in message_stream: - match message.type: - case ( - DatasourceMessage.MessageType.IMAGE_LINK - | DatasourceMessage.MessageType.BINARY_LINK - | DatasourceMessage.MessageType.IMAGE - ): - assert isinstance(message.message, DatasourceMessage.TextMessage) - - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE - - datasource_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - case DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - ) - case DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - case 
DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - case DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - case DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - case DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) - case ( - DatasourceMessage.MessageType.BLOB_CHUNK - | DatasourceMessage.MessageType.LOG - | DatasourceMessage.MessageType.RETRIEVER_RESOURCES - ): - pass - - # mark the end of the stream - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={**variables}, - metadata={ - WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info, - }, - inputs=parameters_for_log, - ) - ) - @classmethod def version(cls) -> str: return "1" - - def _transform_datasource_file_message( - self, - messages: Generator[DatasourceMessage, None, None], - parameters_for_log: dict[str, Any], - datasource_info: dict[str, Any], - variable_pool: 
VariablePool, - datasource_type: DatasourceProviderType, - ) -> Generator: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( - messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - ) - file = None - for message in message_stream: - if message.type == DatasourceMessage.MessageType.BINARY_LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE - - datasource_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - if file: - variable_pool.add([self._node_id, "file"], file) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - outputs={ - "file": file, - "datasource_type": datasource_type, - }, - ) - ) diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index c442e01854..59be4c54ef 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -21,12 +21,12 @@ from docx.table import Table from docx.text.paragraph import Paragraph from core.helper import ssrf_proxy -from 
core.variables import ArrayFileSegment -from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod, file_manager from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from core.workflow.variables import ArrayFileSegment +from core.workflow.variables.segments import ArrayStringSegment, FileSegment from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 8f180b47b5..de14c8c517 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -10,11 +10,9 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from core.helper.ssrf_proxy import ssrf_proxy -from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.file.enums import FileTransferMethod -from core.workflow.file.file_manager import file_manager as default_file_manager from core.workflow.runtime import VariablePool +from core.workflow.variables.segments import ArrayFileSegment, FileSegment from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( @@ -81,8 +79,8 @@ class Executor: http_request_config: HttpRequestNodeConfig, max_retries: int | None = None, ssl_verify: bool | None = None, - http_client: HttpClientProtocol | None = None, - file_manager: FileManagerProtocol | None = None, + http_client: HttpClientProtocol, + file_manager: FileManagerProtocol, ): self._http_request_config = http_request_config # If authorization API key is present, convert the API key using the variable pool @@ -116,8 +114,8 @@ class Executor: self.max_retries = ( 
max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries ) - self._http_client = http_client or ssrf_proxy - self._file_manager = file_manager or default_file_manager + self._http_client = http_client + self._file_manager = file_manager # init template self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index d45775652f..11458db758 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -3,18 +3,15 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.helper.ssrf_proxy import ssrf_proxy -from core.tools.tool_file_manager import ToolFileManager -from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod -from core.workflow.file.file_manager import file_manager as default_file_manager from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.http_request.executor import Executor -from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol +from core.workflow.variables.segments import ArrayFileSegment from factories import file_factory from .config import build_http_request_config, resolve_http_request_config @@ -45,9 +42,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_runtime_state: "GraphRuntimeState", *, http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol | None = None, - tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, 
- file_manager: FileManagerProtocol | None = None, + http_client: HttpClientProtocol, + tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], + file_manager: FileManagerProtocol, ) -> None: super().__init__( id=id, @@ -55,10 +52,11 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + self._http_request_config = http_request_config - self._http_client = http_client or ssrf_proxy + self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager or default_file_manager + self._file_manager = file_manager @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/core/workflow/nodes/human_input/entities.py index 72d4fc675b..a4473dfa7d 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/core/workflow/nodes/human_input/entities.py @@ -10,10 +10,10 @@ from typing import Annotated, Any, ClassVar, Literal, Self from pydantic import BaseModel, Field, field_validator, model_validator -from core.variables.consts import SELECTORS_LENGTH from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from core.workflow.variables.consts import SELECTORS_LENGTH from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 25a881ea7d..5e7aa2a751 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs from 
core.model_runtime.entities.llm_entities import LLMUsage -from core.variables import IntegerVariable, NoneSegment -from core.variables.segments import ArrayAnySegment, ArraySegment -from core.variables.variables import Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import ( NodeExecutionType, @@ -36,6 +33,9 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.node import Node from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.runtime import VariablePool +from core.workflow.variables import IntegerVariable, NoneSegment +from core.workflow.variables.segments import ArrayAnySegment, ArraySegment +from core.workflow.variables.variables import Variable from libs.datetime_utils import naive_utc_now from .exc import ( diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b25c3a3d29..0cfd39e485 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -5,12 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder -from core.variables import ( - ArrayFileSegment, - FileSegment, - StringSegment, -) -from core.variables.segments import ArrayObjectSegment from core.workflow.entities import GraphInitParams from core.workflow.enums import ( NodeType, @@ -22,6 +16,12 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.node import Node from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.repositories.rag_retrieval_protocol import 
KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source +from core.workflow.variables import ( + ArrayFileSegment, + FileSegment, + StringSegment, +) +from core.workflow.variables.segments import ArrayObjectSegment from .entities import KnowledgeRetrievalNodeData from .exc import ( diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3978a79550..d9ef16fbe7 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,12 +1,12 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 341a1c1a4c..f753e19897 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -5,21 +5,17 @@ from sqlalchemy import select, update from sqlalchemy.orm import Session from configs import dify_config -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from 
core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment from core.workflow.enums import SystemVariableKey from core.workflow.file.models import File -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import VariablePool +from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.model import Conversation @@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID from .exc import InvalidVariableTypeError -def fetch_model_config( - *, - node_data_model: ModelConfig, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, -) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, +def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, 
model_instance.model_type_instance).get_model_schema( + model_instance.model_name, + model_instance.credentials, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - stop: list[str] = [] - if "stop" in node_data_model.completion_params: - stop = node_data_model.completion_params.pop("stop") - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - return model_instance, ModelConfigWithCredentialsEntity( - provider=node_data_model.provider, - model=node_data_model.name, - model_schema=model_schema, - mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, - credentials=credentials, - parameters=node_data_model.completion_params, - stop=stop, - ) + raise ValueError(f"Model schema not found for {model_instance.model_name}") + return model_schema def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 0259434d90..057a144e89 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -11,11 +11,9 @@ from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import select -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import ( ImagePromptMessageContent, @@ 
-38,20 +36,12 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.signature import sign_upload_file -from core.variables import ( - ArrayFileSegment, - ArraySegment, - FileSegment, - NoneSegment, - ObjectSegment, - StringSegment, -) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import ( @@ -72,8 +62,16 @@ from core.workflow.node_events import ( from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory from core.workflow.runtime import VariablePool +from core.workflow.variables import ( + ArrayFileSegment, + ArraySegment, + FileSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) from extensions.ext_database import db from models.dataset import SegmentAttachmentBinding from models.model import UploadFile @@ -83,7 +81,6 @@ from .entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, - ModelConfig, ) from .exc import ( InvalidContextStructureError, @@ -116,6 +113,8 @@ class LLMNode(Node[LLMNodeData]): _llm_file_saver: LLMFileSaver _credentials_provider: 
CredentialsProvider _model_factory: ModelFactory + _model_instance: ModelInstance + _memory: PromptMessageMemory | None def __init__( self, @@ -126,6 +125,8 @@ class LLMNode(Node[LLMNodeData]): *, credentials_provider: CredentialsProvider, model_factory: ModelFactory, + model_instance: ModelInstance, + memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -139,6 +140,8 @@ class LLMNode(Node[LLMNodeData]): self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance + self._memory = memory if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -202,29 +205,12 @@ class LLMNode(Node[LLMNodeData]): node_inputs["#context_files#"] = [file.model_dump() for file in context_files] # fetch model config - model_instance, model_config = self._fetch_model_config( - node_data_model=self.node_data.model, - ) - model_name = getattr(model_instance, "model_name", None) - if not isinstance(model_name, str): - model_name = model_config.model - model_provider = getattr(model_instance, "provider", None) - if not isinstance(model_provider, str): - model_provider = model_config.provider - model_schema = model_instance.model_type_instance.get_model_schema( - model_name, - model_instance.credentials, - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_name}") + model_instance = self._model_instance + model_name = model_instance.model_name + model_provider = model_instance.provider + model_stop = model_instance.stop - # fetch memory - memory = llm_utils.fetch_memory( - variable_pool=variable_pool, - app_id=self.app_id, - node_data_memory=self.node_data.memory, - model_instance=model_instance, - ) + memory = self._memory query: str | None = None if self.node_data.memory: @@ -240,9 +226,7 @@ class LLMNode(Node[LLMNodeData]): context=context, memory=memory, model_instance=model_instance, - model_schema=model_schema, - 
model_parameters=self.node_data.model.completion_params, - stop=model_config.stop, + stop=model_stop, prompt_template=self.node_data.prompt_template, memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, @@ -254,7 +238,6 @@ class LLMNode(Node[LLMNodeData]): # handle invoke result generator = LLMNode.invoke_llm( - node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -371,7 +354,6 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, @@ -384,11 +366,10 @@ class LLMNode(Node[LLMNodeData]): node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_schema = model_instance.model_type_instance.get_model_schema( - node_data_model.name, model_instance.credentials - ) - if not model_schema: - raise ValueError(f"Model schema not found for {node_data_model.name}") + model_parameters = model_instance.parameters + invoke_model_parameters = dict(model_parameters) + + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( @@ -402,7 +383,7 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, json_schema=output_schema, - model_parameters=node_data_model.completion_params, + model_parameters=invoke_model_parameters, stop=list(stop or []), stream=True, user=user_id, @@ -412,7 +393,7 @@ class LLMNode(Node[LLMNodeData]): invoke_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), - model_parameters=node_data_model.completion_params, + model_parameters=invoke_model_parameters, stop=list(stop or []), stream=True, user=user_id, @@ -771,33 +752,14 @@ class 
LLMNode(Node[LLMNodeData]): return None - def _fetch_model_config( - self, - *, - node_data_model: ModelConfig, - ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model, model_config_with_cred = llm_utils.fetch_model_config( - node_data_model=node_data_model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - completion_params = model_config_with_cred.parameters - - model_config_with_cred.parameters = completion_params - # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`. - node_data_model.completion_params = completion_params - return model, model_config_with_cred - @staticmethod def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence[File], context: str | None = None, - memory: TokenBufferMemory | None = None, + memory: PromptMessageMemory | None = None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -808,6 +770,7 @@ class LLMNode(Node[LLMNodeData]): context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) if isinstance(prompt_template, list): # For chat model @@ -826,8 +789,6 @@ class LLMNode(Node[LLMNodeData]): memory=memory, memory_config=memory_config, model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) # Extend prompt_messages with memory messages prompt_messages.extend(memory_messages) @@ -865,8 +826,6 @@ class LLMNode(Node[LLMNodeData]): memory=memory, memory_config=memory_config, model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) # Insert histories into the 
prompt prompt_content = prompt_messages[0].content @@ -1316,23 +1275,23 @@ def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> int: rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: + for parameter_rule in runtime_model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_parameters.get(parameter_rule.name) - or model_parameters.get(str(parameter_rule.use_template)) + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) or 0 ) @@ -1344,11 +1303,9 @@ def _calculate_rest_token( def _handle_memory_chat_mode( *, - memory: TokenBufferMemory | None, + memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model @@ -1356,8 +1313,6 @@ def _handle_memory_chat_mode( rest_tokens = _calculate_rest_token( prompt_messages=[], model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, @@ -1368,11 +1323,9 @@ def _handle_memory_chat_mode( def _handle_memory_completion_mode( *, - 
memory: TokenBufferMemory | None, + memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> str: memory_text = "" # Get history text from memory for completion model @@ -1380,20 +1333,51 @@ def _handle_memory_completion_mode( rest_tokens = _calculate_rest_token( prompt_messages=[], model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) if not memory_config.role_prefix: raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = memory.get_history_prompt_text( + memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + memory_text = _convert_history_messages_to_text( + history_messages=memory_messages, human_prefix=memory_config.role_prefix.user, ai_prefix=memory_config.role_prefix.assistant, ) return memory_text +def _convert_history_messages_to_text( + *, + history_messages: Sequence[PromptMessage], + human_prefix: str, + ai_prefix: str, +) -> str: + string_messages: list[str] = [] + for message in history_messages: + if message.role == PromptMessageRole.USER: + role = human_prefix + elif message.role == PromptMessageRole.ASSISTANT: + role = ai_prefix + else: + continue + + if isinstance(message.content, list): + content_parts = [] + for content in message.content: + if isinstance(content, TextPromptMessageContent): + content_parts.append(content.data) + elif isinstance(content, ImagePromptMessageContent): + content_parts.append("[image]") + + inner_msg = "\n".join(content_parts) + string_messages.append(f"{role}: {inner_msg}") + else: + string_messages.append(f"{role}: {message.content}") + return "\n".join(string_messages) + + def _handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, diff --git 
a/api/core/workflow/nodes/llm/protocols.py b/api/core/workflow/nodes/llm/protocols.py index 8e0365299d..5bca04165a 100644 --- a/api/core/workflow/nodes/llm/protocols.py +++ b/api/core/workflow/nodes/llm/protocols.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Sequence from typing import Any, Protocol from core.model_manager import ModelInstance +from core.model_runtime.entities import PromptMessage class CredentialsProvider(Protocol): @@ -19,3 +21,13 @@ class ModelFactory(Protocol): def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: """Create a model instance that is ready for schema lookup and invocation.""" ... + + +class PromptMessageMemory(Protocol): + """Port for loading memory as prompt messages for LLM nodes.""" + + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: int | None = None + ) -> Sequence[PromptMessage]: + """Return historical prompt messages constrained by token/message limits.""" + ... 
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 92a8702fc3..4090f27799 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -3,9 +3,9 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +from core.workflow.variables.types import SegmentType _VALID_VAR_TYPE = frozenset( [ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 241a186a94..c546df1fba 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -6,7 +6,6 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast from core.model_runtime.entities.llm_entities import LLMUsage -from core.variables import Segment, SegmentType from core.workflow.enums import ( NodeExecutionType, NodeType, @@ -31,6 +30,7 @@ from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.node import Node from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData from core.workflow.utils.condition.processor import ConditionProcessor +from core.workflow.variables import Segment, SegmentType from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable from libs.datetime_utils import naive_utc_now diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 4e3819c4cf..90d78ae429 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -8,9 +8,9 @@ from pydantic import ( ) from 
core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig +from core.workflow.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" _OLD_SELECT_TYPE_NAME = "select" diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py index a1707a2461..5a58780575 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -1,6 +1,6 @@ from typing import Any -from core.variables.types import SegmentType +from core.workflow.variables.types import SegmentType class ParameterExtractorNodeError(ValueError): diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index f549d44efa..66ef17e585 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,7 +5,6 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import ImagePromptMessageContent @@ -25,14 +24,14 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.variables.types import ArrayValidation, SegmentType from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, 
WorkflowNodeExecutionStatus from core.workflow.file import File from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm import ModelConfig, llm_utils +from core.workflow.nodes.llm import llm_utils from core.workflow.runtime import VariablePool +from core.workflow.variables.types import ArrayValidation, SegmentType from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData @@ -95,8 +94,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = NodeType.PARAMETER_EXTRACTOR - _model_instance: ModelInstance | None = None - _model_config: ModelConfigWithCredentialsEntity | None = None + _model_instance: ModelInstance _credentials_provider: "CredentialsProvider" _model_factory: "ModelFactory" @@ -109,6 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): *, credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", + model_instance: ModelInstance, ) -> None: super().__init__( id=id, @@ -118,6 +117,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -155,18 +155,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): else [] ) - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance = self._model_instance if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise InvalidModelTypeError("Model is not a Large Language Model") - llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema( - model=model_config.model, - credentials=model_config.credentials, - ) - if not 
model_schema: - raise ModelSchemaNotFoundError("Model schema not found") - + try: + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + except ValueError as exc: + raise ModelSchemaNotFoundError("Model schema not found") from exc # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -184,7 +180,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=node_data, query=query, variable_pool=self.graph_runtime_state.variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=node_data.vision.configs.detail, @@ -195,7 +191,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data=node_data, query=query, variable_pool=self.graph_runtime_state.variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=node_data.vision.configs.detail, @@ -211,24 +207,23 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): } process_data = { - "model_mode": model_config.mode, + "model_mode": node_data.model.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages + model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), "tool_call": None, - "model_provider": model_config.provider, - "model_name": model_config.model, + "model_provider": model_instance.provider, + "model_name": model_instance.model_name, } try: text, usage, tool_call = self._invoke( - node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, tools=prompt_message_tools, - stop=model_config.stop, + stop=model_instance.stop, ) process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) @@ -290,17 +285,16 @@ class 
ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: list[str], + stop: Sequence[str], ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=node_data_model.completion_params, + model_parameters=dict(model_instance.parameters), tools=tools, - stop=stop, + stop=list(stop), stream=False, user=self.user_id, ) @@ -324,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -337,7 +331,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", + ) prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) @@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: 
Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=data, query=query, variable_pool=variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=vision_detail, @@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=data, query=query, variable_pool=variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=vision_detail, @@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -454,7 +454,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", ) prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token @@ -467,7 +471,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=memory, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -478,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: 
ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -488,7 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", ) prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, @@ -508,7 +516,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -769,21 +777,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, context: str | None, ) -> int: + try: + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + except ValueError as exc: + raise ModelSchemaNotFoundError("Model schema not found") from exc prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - model_instance, model_config = self._fetch_model_config(node_data.model) - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - - llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) - if not model_schema: - raise ModelSchemaNotFoundError("Model schema not found") - - if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, 
ModelFeature.MULTI_TOOL_CALL}: + if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) @@ -796,27 +799,28 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context=context, memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, ) rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) curr_message_tokens = ( - model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + model_type_instance.get_num_tokens( + model_instance.model_name, model_instance.credentials, prompt_messages + ) + + 1000 ) # add 1000 to ensure tool call messages max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") + model_instance.parameters.get(parameter_rule.name) + or model_instance.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -824,21 +828,6 @@ class 
ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens - def _fetch_model_config( - self, node_data_model: ModelConfig - ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config. - """ - if not self._model_instance or not self._model_config: - self._model_instance, self._model_config = llm_utils.fetch_model_config( - node_data_model=node_data_model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - - return self._model_instance, self._model_config - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py index a1f3e20835..fda524d701 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/core/workflow/nodes/protocols.py @@ -27,3 +27,16 @@ class HttpClientProtocol(Protocol): class FileManagerProtocol(Protocol): def download(self, f: File, /) -> bytes: ... + + +class ToolFileManagerProtocol(Protocol): + def create_file_by_raw( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: ... 
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 3f41c0d0b7..464d9b6b9c 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -3,12 +3,10 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities import GraphInitParams @@ -22,7 +20,12 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils +from core.workflow.nodes.llm import ( + LLMNode, + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + llm_utils, +) from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from libs.json_in_md_parser import parse_and_check_json_markdown @@ -52,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _llm_file_saver: LLMFileSaver _credentials_provider: "CredentialsProvider" _model_factory: 
"ModelFactory" + _model_instance: ModelInstance def __init__( self, @@ -62,6 +66,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): *, credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", + model_instance: ModelInstance, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -75,6 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -95,18 +101,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None query = variable.value if variable else None variables = {"query": query} - # fetch model config - model_instance, model_config = llm_utils.fetch_model_config( - node_data_model=node_data.model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - model_schema = model_instance.model_type_instance.get_model_schema( - model_instance.model_name, - model_instance.credentials, - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + # fetch model instance + model_instance = self._model_instance # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -131,7 +127,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): rest_token = self._calculate_rest_token( node_data=node_data, query=query or "", - model_config=model_config, + model_instance=model_instance, context="", ) prompt_template = self._get_prompt_template( @@ -149,9 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): sys_query="", memory=memory, model_instance=model_instance, - model_schema=model_schema, - model_parameters=node_data.model.completion_params, - stop=model_config.stop, + 
stop=model_instance.stop, sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, @@ -166,7 +160,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): try: # handle invoke result generator = LLMNode.invoke_llm( - node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -205,14 +198,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_name = classes_map[category_id_result] category_id = category_id_result process_data = { - "model_mode": model_config.mode, + "model_mode": node_data.model.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages + model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), "finish_reason": finish_reason, - "model_provider": model_config.provider, - "model_name": model_config.model, + "model_provider": model_instance.provider, + "model_name": model_instance.model_name, } outputs = { "class_name": category_name, @@ -285,39 +278,40 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, context: str | None, ) -> int: - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages = prompt_transform.get_prompt( + prompt_messages, _ = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, - inputs={}, - query="", - files=[], + sys_query="", + sys_files=[], context=context, - memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, + stop=model_instance.stop, + memory_config=node_data.memory, + 
vision_enabled=False, + vision_detail=node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], ) rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") + model_instance.parameters.get(parameter_rule.name) + or model_instance.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a7bf7d6642..0d7270a282 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -11,8 +11,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayAnySegment, ArrayFileSegment -from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( NodeType, SystemVariableKey, @@ -23,6 +21,8 @@ from core.workflow.file import File, FileTransferMethod from core.workflow.node_events import NodeEventBase, 
NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment +from core.workflow.variables.variables import ArrayAnyVariable from extensions.ext_database import db from factories import file_factory from models import ToolFile diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 060afd6ae6..9f6046c11a 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,14 +2,14 @@ import logging from collections.abc import Mapping from typing import Any -from core.variables.types import SegmentType -from core.variables.variables import FileVariable from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import NodeExecutionType, NodeType from core.workflow.file import FileTransferMethod from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node +from core.workflow.variables.types import SegmentType +from core.workflow.variables.variables import FileVariable from factories import file_factory from factories.variable_factory import build_segment_with_type diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index aab17aad22..febbf1d1d6 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,7 +1,7 @@ from pydantic import BaseModel -from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData +from core.workflow.variables.types import SegmentType class AdvancedSettings(BaseModel): diff --git 
a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 4b3a2304e7..762b7dab07 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping -from core.variables.segments import Segment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData +from core.workflow.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 04a7323739..37fde9d1b0 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -3,9 +3,9 @@ from typing import Any, TypeVar from pydantic import BaseModel -from core.variables import Segment -from core.variables.consts import SELECTORS_LENGTH -from core.variables.types import SegmentType +from core.workflow.variables import Segment +from core.workflow.variables.consts import SELECTORS_LENGTH +from core.workflow.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables # to minimize risk of collision with user-defined variable names. 
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 9f5818f4bb..b987949541 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,7 +1,6 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, VariableBase from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -9,6 +8,7 @@ from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.variables import SegmentType, VariableBase from .node_data import VariableAssignerData, WriteMode diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index f5490fb900..ce3fe9620c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,6 @@ from typing import Any -from core.variables import SegmentType +from core.workflow.variables import SegmentType from .enums import Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 5857702e72..0d4c3d2774 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -2,14 +2,14 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.variables import SegmentType, VariableBase -from core.variables.consts import SELECTORS_LENGTH from 
core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from core.workflow.variables import SegmentType, VariableBase +from core.workflow.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem diff --git a/api/core/workflow/repositories/datasource_manager_protocol.py b/api/core/workflow/repositories/datasource_manager_protocol.py new file mode 100644 index 0000000000..4acf486bef --- /dev/null +++ b/api/core/workflow/repositories/datasource_manager_protocol.py @@ -0,0 +1,50 @@ +from collections.abc import Generator +from typing import Any, Protocol + +from pydantic import BaseModel + +from core.workflow.file import File +from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent + + +class DatasourceParameter(BaseModel): + workspace_id: str + page_id: str + type: str + + +class OnlineDriveDownloadFileParam(BaseModel): + id: str + bucket: str + + +class DatasourceFinal(BaseModel): + data: dict[str, Any] | None = None + + +class DatasourceManagerProtocol(Protocol): + @classmethod + def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ... 
+ + @classmethod + def stream_node_events( + cls, + *, + node_id: str, + user_id: str, + datasource_name: str, + datasource_type: str, + provider_id: str, + tenant_id: str, + provider: str, + plugin_id: str, + credential_id: str, + parameters_for_log: dict[str, Any], + datasource_info: dict[str, Any], + variable_pool: Any, + datasource_param: DatasourceParameter | None = None, + online_drive_request: OnlineDriveDownloadFileParam | None = None, + ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ... + + @classmethod + def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ... diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/core/workflow/runtime/graph_runtime_state_protocol.py index bfbb5ba704..81d87e5a74 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/core/workflow/runtime/graph_runtime_state_protocol.py @@ -2,8 +2,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage -from core.variables.segments import Segment from core.workflow.system_variable import SystemVariableReadOnlyView +from core.workflow.variables.segments import Segment class ReadOnlyVariablePool(Protocol): diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/core/workflow/runtime/read_only_wrappers.py index d3e4c60d9b..25a834a539 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/core/workflow/runtime/read_only_wrappers.py @@ -5,8 +5,8 @@ from copy import deepcopy from typing import Any from core.model_runtime.entities.llm_entities import LLMUsage -from core.variables.segments import Segment from core.workflow.system_variable import SystemVariableReadOnlyView +from core.workflow.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool diff --git a/api/core/workflow/runtime/variable_pool.py 
b/api/core/workflow/runtime/variable_pool.py index 0ba9d8b3a8..48ad102b43 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/core/workflow/runtime/variable_pool.py @@ -8,10 +8,6 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field -from core.variables import Segment, SegmentGroup, VariableBase -from core.variables.consts import SELECTORS_LENGTH -from core.variables.segments import FileSegment, ObjectSegment -from core.variables.variables import RAGPipelineVariableInput, Variable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -20,6 +16,10 @@ from core.workflow.constants import ( ) from core.workflow.file import File, FileAttribute, file_manager from core.workflow.system_variable import SystemVariable +from core.workflow.variables import Segment, SegmentGroup, VariableBase +from core.workflow.variables.consts import SELECTORS_LENGTH +from core.workflow.variables.segments import FileSegment, ObjectSegment +from core.workflow.variables.variables import RAGPipelineVariableInput, Variable from factories import variable_factory VariableValue = Union[str, int, float, dict[str, object], list[object], File] diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index c3f25a4d62..4e635cc2f2 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -2,10 +2,10 @@ import json from collections.abc import Mapping, Sequence from typing import Literal, NamedTuple -from core.variables import ArrayFileSegment -from core.variables.segments import ArrayBooleanSegment, BooleanSegment from core.workflow.file import FileAttribute, file_manager from core.workflow.runtime import VariablePool +from core.workflow.variables import ArrayFileSegment +from core.workflow.variables.segments import ArrayBooleanSegment, BooleanSegment from .entities import Condition, SubCondition, 
SupportedComparisonOperator diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 7992785fe1..dfa4ce2e75 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -2,9 +2,9 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.variables import VariableBase -from core.variables.consts import SELECTORS_LENGTH from core.workflow.runtime import VariablePool +from core.workflow.variables import VariableBase +from core.workflow.variables.consts import SELECTORS_LENGTH class VariableLoader(Protocol): diff --git a/api/core/variables/__init__.py b/api/core/workflow/variables/__init__.py similarity index 100% rename from api/core/variables/__init__.py rename to api/core/workflow/variables/__init__.py diff --git a/api/core/variables/consts.py b/api/core/workflow/variables/consts.py similarity index 100% rename from api/core/variables/consts.py rename to api/core/workflow/variables/consts.py diff --git a/api/core/variables/exc.py b/api/core/workflow/variables/exc.py similarity index 100% rename from api/core/variables/exc.py rename to api/core/workflow/variables/exc.py diff --git a/api/core/variables/segment_group.py b/api/core/workflow/variables/segment_group.py similarity index 100% rename from api/core/variables/segment_group.py rename to api/core/workflow/variables/segment_group.py diff --git a/api/core/variables/segments.py b/api/core/workflow/variables/segments.py similarity index 100% rename from api/core/variables/segments.py rename to api/core/workflow/variables/segments.py diff --git a/api/core/variables/types.py b/api/core/workflow/variables/types.py similarity index 100% rename from api/core/variables/types.py rename to api/core/workflow/variables/types.py diff --git a/api/core/variables/utils.py b/api/core/workflow/variables/utils.py similarity index 100% rename from api/core/variables/utils.py rename to 
api/core/workflow/variables/utils.py diff --git a/api/core/variables/variables.py b/api/core/workflow/variables/variables.py similarity index 95% rename from api/core/variables/variables.py rename to api/core/workflow/variables/variables.py index 338d81df78..af866283da 100644 --- a/api/core/variables/variables.py +++ b/api/core/workflow/variables/variables.py @@ -4,8 +4,6 @@ from uuid import uuid4 from pydantic import BaseModel, Discriminator, Field, Tag -from core.helper import encrypter - from .segments import ( ArrayAnySegment, ArrayBooleanSegment, @@ -27,6 +25,14 @@ from .segments import ( from .types import SegmentType +def _obfuscated_token(token: str) -> str: + if not token: + return token + if len(token) <= 8: + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] + + class VariableBase(Segment): """ A variable is a segment that has a name. @@ -86,7 +92,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return encrypter.obfuscated_token(self.value) + return _obfuscated_token(self.value) class NoneVariable(NoneSegment, VariableBase): diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 93c6a31960..a192b884f7 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -4,8 +4,8 @@ from typing import Any, overload from pydantic import BaseModel -from core.variables import Segment from core.workflow.file.models import File +from core.workflow.variables import Segment class WorkflowRuntimeTypeConverter: diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index c6589dd99f..66d1c977d6 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -9,11 +9,11 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from core.variables import Segment from core.workflow.enums import NodeType from 
core.workflow.file.models import File from core.workflow.graph_events import GraphNodeEventBase from core.workflow.nodes.base.node import Node +from core.workflow.variables import Segment from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index fc151af691..82cb865b8b 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -8,9 +8,9 @@ from typing import Any from opentelemetry.trace import Span -from core.variables import Segment from core.workflow.graph_events import GraphNodeEventBase from core.workflow.nodes.base.node import Node +from core.workflow.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a7cfb6a65e..b74d9517f4 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -3,8 +3,13 @@ from typing import Any, cast from uuid import uuid4 from configs import dify_config -from core.variables.exc import VariableError -from core.variables.segments import ( +from core.workflow.constants import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) +from core.workflow.file import File +from core.workflow.variables.exc import VariableError +from core.workflow.variables.segments import ( ArrayAnySegment, ArrayBooleanSegment, ArrayFileSegment, @@ -21,8 +26,8 @@ from core.variables.segments import ( Segment, StringSegment, ) -from core.variables.types import SegmentType -from core.variables.variables import ( +from core.workflow.variables.types import SegmentType +from core.workflow.variables.variables import ( ArrayAnyVariable, ArrayBooleanVariable, ArrayFileVariable, @@ -39,11 +44,6 @@ from core.variables.variables import ( StringVariable, VariableBase, ) 
-from core.workflow.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) -from core.workflow.file import File class UnsupportedSegmentTypeError(Exception): diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b2b793d40e..461c163e2f 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from core.variables.segments import Segment -from core.variables.types import SegmentType +from core.workflow.variables.segments import Segment +from core.workflow.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 2755f77f61..019949e105 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.variables import SecretVariable, SegmentType, VariableBase +from core.workflow.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/login.py b/api/libs/login.py index 73caa492fe..69e2b58426 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -13,6 +13,8 @@ from libs.token import check_csrf_token from models import Account if TYPE_CHECKING: + from flask.typing import ResponseReturnValue + from models.model import EndUser @@ -38,7 +40,7 @@ P = ParamSpec("P") R = TypeVar("R") -def login_required(func: Callable[P, R]): +def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]: """ If you decorate a view with this, it will ensure that the current user is logged in and authenticated before calling the actual view. 
(If they are @@ -73,7 +75,7 @@ def login_required(func: Callable[P, R]): """ @wraps(func) - def decorated_view(*args: P.args, **kwargs: P.kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif current_user is not None and not current_user.is_authenticated: diff --git a/api/models/workflow.py b/api/models/workflow.py index c88a48632a..6a86251216 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -22,8 +22,6 @@ from sqlalchemy import ( from sqlalchemy.orm import Mapped, declared_attr, mapped_column from typing_extensions import deprecated -from core.variables import utils as variable_utils -from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, @@ -33,6 +31,8 @@ from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, from core.workflow.enums import NodeType, WorkflowExecutionStatus from core.workflow.file.constants import maybe_file_object from core.workflow.file.models import File +from core.workflow.variables import utils as variable_utils +from core.workflow.variables.variables import FloatVariable, IntegerVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -46,7 +46,7 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Segment, SegmentType, VariableBase +from core.workflow.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper diff --git a/api/services/account_service.py b/api/services/account_service.py index b4b25a1194..648b5e834f 100644 --- 
a/api/services/account_service.py +++ b/api/services/account_service.py @@ -289,6 +289,12 @@ class AccountService: TenantService.create_owner_tenant_if_not_exist(account=account) + # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). + if dify_config.ENTERPRISE_ENABLED: + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(str(account.id)) + return account @staticmethod @@ -1407,6 +1413,12 @@ class RegisterService: tenant_was_created.send(tenant) db.session.commit() + + # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). + if dify_config.ENTERPRISE_ENABLED: + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: db.session.rollback() logger.exception("Register failed") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 295d48d8a1..4c87150cf7 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -10,7 +10,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from core.variables.types import SegmentType +from core.workflow.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now @@ -180,6 +180,14 @@ class ConversationService: @classmethod def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None): + """ + Delete a conversation only if it belongs to the given user and app context. + + Raises: + ConversationNotExistsError: When the conversation is not visible to the current user. 
+ """ + conversation = cls.get_conversation(app_model, conversation_id, user) + try: logger.info( "Initiating conversation deletion for app_name %s, conversation_id: %s", @@ -187,10 +195,10 @@ class ConversationService: conversation_id, ) - db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) + db.session.delete(conversation) db.session.commit() - delete_conversation_related_data.delay(conversation_id) + delete_conversation_related_data.delay(conversation.id) except Exception as e: db.session.rollback() diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 92008d5ff1..b0012d6f6a 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.variables.variables import VariableBase +from core.workflow.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index e3832475aa..744b7992f8 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -39,6 +39,9 @@ class BaseRequest: endpoint: str, json: Any | None = None, params: Mapping[str, Any] | None = None, + *, + timeout: float | httpx.Timeout | None = None, + raise_for_status: bool = False, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} url = f"{cls.base_url}{endpoint}" @@ -53,7 +56,16 @@ class BaseRequest: logger.debug("Failed to generate traceparent header", exc_info=True) with httpx.Client(mounts=mounts) as client: - response = client.request(method, url, json=json, params=params, headers=headers) + # IMPORTANT: + # - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default. 
+ # - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set. + request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers} + if timeout is not None: + request_kwargs["timeout"] = timeout + + response = client.request(method, url, **request_kwargs) + if raise_for_status: + response.raise_for_status() return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index a5133dfcb4..71d456aa2d 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,9 +1,16 @@ +import logging +import uuid from datetime import datetime -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator +from configs import dify_config from services.enterprise.base import EnterpriseRequest +logger = logging.getLogger(__name__) + +DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 + class WebAppSettings(BaseModel): access_mode: str = Field( @@ -30,6 +37,55 @@ class WorkspacePermission(BaseModel): ) +class DefaultWorkspaceJoinResult(BaseModel): + """ + Result of ensuring an account is a member of the enterprise default workspace. + + - joined=True is idempotent (already a member also returns True) + - joined=False means enterprise default workspace is not configured or invalid/archived + """ + + workspace_id: str = Field(default="", alias="workspaceId") + joined: bool + message: str + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + @model_validator(mode="after") + def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": + if self.joined and not self.workspace_id: + raise ValueError("workspace_id must be non-empty when joined is True") + return self + + +def try_join_default_workspace(account_id: str) -> None: + """ + Enterprise-only side-effect: ensure account is a member of the default workspace. 
+ + This is a best-effort integration. Failures must not block user registration. + """ + + if not dify_config.ENTERPRISE_ENABLED: + return + + try: + result = EnterpriseService.join_default_workspace(account_id=account_id) + if result.joined: + logger.info( + "Joined enterprise default workspace for account %s (workspace_id=%s)", + account_id, + result.workspace_id, + ) + else: + logger.info( + "Skipped joining enterprise default workspace for account %s (message=%s)", + account_id, + result.message, + ) + except Exception: + logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True) + + class EnterpriseService: @classmethod def get_info(cls): @@ -39,6 +95,34 @@ class EnterpriseService: def get_workspace_info(cls, tenant_id: str): return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + @classmethod + def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult: + """ + Call enterprise inner API to add an account to the default workspace. + + NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix, + so the endpoint here is `/default-workspace/members`. + """ + + # Ensure we are sending a UUID-shaped string (enterprise side validates too). 
+ try: + uuid.UUID(account_id) + except ValueError as e: + raise ValueError(f"account_id must be a valid UUID: {account_id}") from e + + data = EnterpriseRequest.send_request( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, + raise_for_status=True, + ) + if not isinstance(data, dict): + raise ValueError("Invalid response format from enterprise default workspace API") + if "joined" not in data or "message" not in data: + raise ValueError("Invalid response payload from enterprise default workspace API") + return DefaultWorkspaceJoinResult.model_validate(data) + @classmethod def get_app_sso_settings_last_update_time(cls) -> datetime: data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 4ae3496cd6..c0f9e4f323 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,7 +36,6 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.variables.variables import VariableBase from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -52,6 +51,7 @@ from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_M from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from core.workflow.runtime import VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables.variables import VariableBase from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination diff --git a/api/services/trigger/webhook_service.py 
b/api/services/trigger/webhook_service.py index edbc7e0cc8..75a1350e60 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -16,9 +16,9 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager -from core.variables.types import SegmentType from core.workflow.enums import NodeType from core.workflow.file.models import FileTransferMethod +from core.workflow.variables.types import SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 056ea4d78a..12be12776a 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -6,7 +6,9 @@ from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload from configs import dify_config -from core.variables.segments import ( +from core.workflow.file.models import File +from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable +from core.workflow.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -18,9 +20,7 @@ from core.variables.segments import ( Segment, StringSegment, ) -from core.variables.utils import dumps_with_segments -from core.workflow.file.models import File -from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable +from core.workflow.variables.utils import dumps_with_segments _MAX_DEPTH = 100 diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 991925ae6b..18ad6c5c16 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,20 +14,20 @@ from sqlalchemy.sql.expression import and_, or_ from 
configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import Segment, StringSegment, VariableBase -from core.variables.consts import SELECTORS_LENGTH -from core.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from core.variables.types import SegmentType -from core.variables.utils import dumps_with_segments from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.enums import SystemVariableKey from core.workflow.file.models import File from core.workflow.nodes import NodeType from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables from core.workflow.variable_loader import VariableLoader +from core.workflow.variables import Segment, StringSegment, VariableBase +from core.workflow.variables.consts import SELECTORS_LENGTH +from core.workflow.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from core.workflow.variables.types import SegmentType +from core.workflow.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index abcd41b1be..406fdae525 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,8 +15,6 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.variables import VariableBase -from core.variables.variables import Variable from core.workflow.entities import GraphInitParams, WorkflowNodeExecution from core.workflow.entities.pause_reason import 
HumanInputRequired from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -41,6 +39,8 @@ from core.workflow.repositories.human_input_form_repository import FormCreatePar from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import load_into_variable_pool +from core.workflow.variables import VariableBase +from core.workflow.variables.variables import Variable from core.workflow.workflow_entry import WorkflowEntry from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py new file mode 100644 index 0000000000..003bb356e5 --- /dev/null +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -0,0 +1,42 @@ +from collections.abc import Generator + +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.workflow.node_events import StreamCompletedEvent + + +def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: + # produce a streamed variable "a"="xy" + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="x", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="y", stream=True), + meta=None, + ) + + +def test_stream_node_events_accumulates_variables(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream()) + events = list( + 
DatasourceManager.stream_node_events( + node_id="A", + user_id="u", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={"user_id": "u"}, + variable_pool=mocker.Mock(), + datasource_param=type("P", (), {"workspace_id": "w", "page_id": "pg", "type": "t"})(), + online_drive_request=None, + ) + ) + assert isinstance(events[-1], StreamCompletedEvent) diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py new file mode 100644 index 0000000000..909d6377ce --- /dev/null +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -0,0 +1,84 @@ +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.node_events import NodeRunResult, StreamCompletedEvent +from core.workflow.nodes.datasource.datasource_node import DatasourceNode + + +class _Seg: + def __init__(self, v): + self.value = v + + +class _VarPool: + def __init__(self, data): + self.data = data + + def get(self, path): + d = self.data + for k in path: + d = d[k] + return _Seg(d) + + def add(self, *_a, **_k): + pass + + +class _GS: + def __init__(self, vp): + self.variable_pool = vp + + +class _GP: + tenant_id = "t1" + app_id = "app-1" + workflow_id = "wf-1" + graph_config = {} + user_id = "u1" + user_from = "account" + invoke_from = "debugger" + call_depth = 0 + + +def test_node_integration_minimal_stream(mocker): + sys_d = { + "sys": { + "datasource_type": "online_document", + "datasource_info": {"workspace_id": "w", "page": {"page_id": "pg", "type": "t"}, "credential_id": ""}, + } + } + vp = _VarPool(sys_d) + + class _Mgr: + @classmethod + def get_icon_url(cls, **_): + return "icon" + + @classmethod + def 
stream_node_events(cls, **_): + yield from () + yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)) + + @classmethod + def get_upload_file_by_id(cls, **_): + raise AssertionError + + node = DatasourceNode( + id="n", + config={ + "id": "n", + "data": { + "type": "datasource", + "version": "1", + "title": "Datasource", + "provider_type": "plugin", + "provider_name": "p", + "plugin_id": "plug", + "datasource_name": "ds", + }, + }, + graph_init_params=_GP(), + graph_runtime_state=_GS(vp), + datasource_manager=_Mgr, + ) + + out = list(node._run()) + assert isinstance(out[-1], StreamCompletedEvent) diff --git a/api/tests/integration_tests/libs/test_api_token_cache_integration.py b/api/tests/integration_tests/libs/test_api_token_cache_integration.py index 166fcb515f..1d7b835fd2 100644 --- a/api/tests/integration_tests/libs/test_api_token_cache_integration.py +++ b/api/tests/integration_tests/libs/test_api_token_cache_integration.py @@ -360,7 +360,7 @@ class TestEndToEndCacheFlow: class TestRedisFailover: """Test behavior when Redis is unavailable.""" - @patch("services.api_token_service.redis_client") + @patch("services.api_token_service.redis_client", autospec=True) def test_graceful_degradation_when_redis_fails(self, mock_redis): """Test system degrades gracefully when Redis is unavailable.""" from redis import RedisError diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index f3a5ba0d11..5faa002fff 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,11 +6,11 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from core.variables.segments import StringSegment -from core.variables.types import SegmentType -from core.variables.variables import 
StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes import NodeType +from core.workflow.variables.segments import StringSegment +from core.workflow.variables.types import SegmentType +from core.workflow.variables.variables import StringVariable from extensions.ext_database import db from extensions.ext_storage import storage from factories.variable_factory import build_segment diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index d020233620..a259ccb2b9 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy import delete from core.db.session_factory import session_factory -from core.variables.segments import StringSegment +from core.workflow.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -191,7 +191,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from core.variables.types import SegmentType + from core.workflow.variables.types import SegmentType from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -422,7 +422,7 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from core.variables.types import SegmentType + from core.workflow.variables.types import SegmentType from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py 
b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py index 210dee4c36..81ebb1d2f7 100644 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -41,17 +41,15 @@ class TestOpenSearchConfig: assert params["connection_class"].__name__ == "Urllib3HttpConnection" assert params["http_auth"] == ("admin", "password") - @patch("boto3.Session") - @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth") + @patch("boto3.Session", autospec=True) + @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth", autospec=True) def test_to_opensearch_params_with_aws_managed_iam( self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock ): mock_credentials = MagicMock() mock_boto_session.return_value.get_credentials.return_value = mock_credentials - mock_auth_instance = MagicMock() - mock_aws_signer_auth.return_value = mock_auth_instance - + mock_auth_instance = mock_aws_signer_auth.return_value aws_region = "ap-southeast-2" aws_service = "aoss" host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com" @@ -157,7 +155,7 @@ class TestOpenSearchVector: doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch("opensearchpy.helpers.bulk") as mock_bulk: + with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) @@ -171,7 +169,7 @@ class TestOpenSearchVector: doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) embedding = [0.1] * 128 - with patch("opensearchpy.helpers.bulk") as mock_bulk: + with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk: mock_bulk.return_value = ([], []) self.vector.add_texts([doc], [embedding]) diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py 
b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 330ebfd54a..cdecdf41d2 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -48,3 +48,19 @@ def get_mocked_fetch_model_config( ) return MagicMock(return_value=(model_instance, model_config)) + + +def get_mocked_fetch_model_instance( + provider: str, + model: str, + mode: str, + credentials: dict, +): + mock_fetch_model_config = get_mocked_fetch_model_config( + provider=provider, + model=model, + mode=mode, + credentials=credentials, + ) + model_instance, _ = mock_fetch_model_config() + return MagicMock(return_value=model_instance) diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 0473d9832a..e0f2363799 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -7,8 +7,11 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory +from core.helper.ssrf_proxy import ssrf_proxy +from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.file.file_manager import file_manager from core.workflow.graph import Graph from core.workflow.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -76,6 +79,9 @@ def init_http_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, ) return node @@ -229,6 +235,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): 
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -716,6 +724,9 @@ def test_nested_object_variable_selector(setup_http_mock): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 1b341e8f21..b5b0fb5334 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -5,13 +5,13 @@ from collections.abc import Generator from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db @@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - # Create node factory - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = Graph.init(graph_config=graph_config, 
node_factory=node_factory) - node = LLMNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - credentials_provider=MagicMock(), - model_factory=MagicMock(), + credentials_provider=MagicMock(spec=CredentialsProvider), + model_factory=MagicMock(spec=ModelFactory), + model_instance=MagicMock(spec=ModelInstance), ) return node @@ -116,8 +109,7 @@ def test_execute_llm(): db.session.close = MagicMock() - # Mock the _fetch_model_config to avoid database calls - def mock_fetch_model_config(*_args, **_kwargs): + def build_mock_model_instance() -> MagicMock: from decimal import Decimal from unittest.mock import MagicMock @@ -125,7 +117,20 @@ def test_execute_llm(): from core.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance - mock_model_instance = MagicMock() + mock_model_instance = MagicMock(spec=ModelInstance) + mock_model_instance.provider = "openai" + mock_model_instance.model_name = "gpt-3.5-turbo" + mock_model_instance.credentials = {} + mock_model_instance.parameters = {} + mock_model_instance.stop = [] + mock_model_instance.model_type_instance = MagicMock() + mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock( + model_properties={}, + parameter_rules=[], + features=[], + ) + mock_model_instance.provider_model_bundle = MagicMock() + mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom" mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), @@ -149,14 +154,7 @@ def test_execute_llm(): ) mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create mock model config - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = "openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.parameters = {} - - return mock_model_instance, mock_model_config + return mock_model_instance # Mock fetch_prompt_messages to avoid 
database calls def mock_fetch_prompt_messages_1(**_kwargs): @@ -167,10 +165,9 @@ def test_execute_llm(): UserPromptMessage(content="what's the weather today?"), ], [] - with ( - patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config), - patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1), - ): + node._model_instance = build_mock_model_instance() + + with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1): # execute node result = node._run() assert isinstance(result, Generator) @@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2(): # Mock db.session.close() db.session.close = MagicMock() - # Mock the _fetch_model_config method - def mock_fetch_model_config(*_args, **_kwargs): + def build_mock_model_instance() -> MagicMock: from decimal import Decimal from unittest.mock import MagicMock @@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2(): from core.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance - mock_model_instance = MagicMock() + mock_model_instance = MagicMock(spec=ModelInstance) + mock_model_instance.provider = "openai" + mock_model_instance.model_name = "gpt-3.5-turbo" + mock_model_instance.credentials = {} + mock_model_instance.parameters = {} + mock_model_instance.stop = [] + mock_model_instance.model_type_instance = MagicMock() + mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock( + model_properties={}, + parameter_rules=[], + features=[], + ) + mock_model_instance.provider_model_bundle = MagicMock() + mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom" mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), @@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2(): ) mock_model_instance.invoke_llm.return_value = mock_llm_result - # Create mock model config - mock_model_config = MagicMock() - mock_model_config.mode = "chat" - mock_model_config.provider = 
"openai" - mock_model_config.model = "gpt-3.5-turbo" - mock_model_config.parameters = {} - - return mock_model_instance, mock_model_config + return mock_model_instance # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): @@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2(): UserPromptMessage(content="what's the weather today?"), ], [] - with ( - patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config), - patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2), - ): + node._model_instance = build_mock_model_instance() + + with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2): # execute node result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 88edc4f9b3..e791f12393 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -4,18 +4,17 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.model_manager import ModelInstance from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom -from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config +from 
tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock @@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - # Create node factory - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) - node = ParameterExtractorNode( id=str(uuid.uuid4()), config=config, @@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict): graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), + model_instance=MagicMock(spec=ModelInstance), ) return node @@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock): } ) - node._fetch_model_config = get_mocked_fetch_model_config( + node._model_instance = get_mocked_fetch_model_instance( provider="langgenius/openai/openai", model="gpt-3.5-turbo", mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, - ) + )() db.session.close = MagicMock() result = node._run() @@ -157,12 +149,12 @@ def test_instructions(setup_model_mock): }, ) - node._fetch_model_config = get_mocked_fetch_model_config( + node._model_instance = get_mocked_fetch_model_instance( provider="langgenius/openai/openai", model="gpt-3.5-turbo", mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, - ) + )() db.session.close = MagicMock() result = node._run() @@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock): }, ) - node._fetch_model_config = get_mocked_fetch_model_config( + node._model_instance = get_mocked_fetch_model_instance( provider="langgenius/openai/openai", model="gpt-3.5-turbo", 
mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, - ) + )() db.session.close = MagicMock() result = node._run() @@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock): }, ) - node._fetch_model_config = get_mocked_fetch_model_config( + node._model_instance = get_mocked_fetch_model_instance( provider="langgenius/openai/openai", model="gpt-3.5-turbo-instruct", mode="completion", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, - ) + )() db.session.close = MagicMock() result = node._run() @@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch): }, ) - node._fetch_model_config = get_mocked_fetch_model_config( + node._model_instance = get_mocked_fetch_model_instance( provider="langgenius/openai/openai", model="gpt-3.5-turbo", mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, - ) + )() # Test the mock before running the actual test monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory")) db.session.close = MagicMock() diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py new file mode 100644 index 0000000000..c08ea2a93b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -0,0 +1,1285 @@ +""" +Comprehensive integration tests for DocumentService status management methods. + +This module contains extensive integration tests for the DocumentService class, +specifically focusing on document status management operations including +pause, recover, retry, batch updates, and renaming. 
+""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest + +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.document import DocumentIndexingError + +FIXED_TIME = datetime.datetime(2023, 1, 1, 12, 0, 0) + + +class DocumentStatusTestDataFactory: + """ + Factory class for creating real test data and helper doubles for document status tests. + + This factory provides static methods to create persisted entities for SQL + assertions and lightweight doubles for collaborator patches. + + The factory methods help maintain consistency across tests and reduce + code duplication when setting up test scenarios. + """ + + @staticmethod + def create_document( + db_session_with_containers, + document_id: str | None = None, + dataset_id: str | None = None, + tenant_id: str | None = None, + name: str = "Test Document", + indexing_status: str = "completed", + is_paused: bool = False, + enabled: bool = True, + archived: bool = False, + paused_by: str | None = None, + paused_at: datetime.datetime | None = None, + data_source_type: str = "upload_file", + data_source_info: dict | None = None, + doc_metadata: dict | None = None, + **kwargs, + ) -> Document: + """ + Create a persisted Document with specified attributes. 
+ + Args: + document_id: Unique identifier for the document + dataset_id: Dataset identifier + tenant_id: Tenant identifier + name: Document name + indexing_status: Current indexing status + is_paused: Whether document is paused + enabled: Whether document is enabled + archived: Whether document is archived + paused_by: ID of user who paused the document + paused_at: Timestamp when document was paused + data_source_type: Type of data source + data_source_info: Data source information dictionary + doc_metadata: Document metadata dictionary + **kwargs: Additional attributes to set on the entity + + Returns: + Persisted Document instance + """ + tenant_id = tenant_id or str(uuid4()) + dataset_id = dataset_id or str(uuid4()) + document_id = document_id or str(uuid4()) + created_by = kwargs.pop("created_by", str(uuid4())) + position = kwargs.pop("position", 1) + + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=f"batch-{uuid4()}", + name=name, + created_from="web", + created_by=created_by, + doc_form="text_model", + ) + document.id = document_id + document.indexing_status = indexing_status + document.is_paused = is_paused + document.enabled = enabled + document.archived = archived + document.paused_by = paused_by + document.paused_at = paused_at + document.doc_metadata = doc_metadata or {} + if indexing_status == "completed" and "completed_at" not in kwargs: + document.completed_at = FIXED_TIME + + for key, value in kwargs.items(): + setattr(document, key, value) + + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_dataset( + db_session_with_containers, + dataset_id: str | None = None, + tenant_id: str | None = None, + name: str = "Test Dataset", + built_in_field_enabled: bool = False, + **kwargs, + ) -> Dataset: + """ + Create a persisted Dataset with 
specified attributes. + + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant identifier + name: Dataset name + built_in_field_enabled: Whether built-in fields are enabled + **kwargs: Additional attributes to set on the entity + + Returns: + Persisted Dataset instance + """ + tenant_id = tenant_id or str(uuid4()) + dataset_id = dataset_id or str(uuid4()) + created_by = kwargs.pop("created_by", str(uuid4())) + + dataset = Dataset( + tenant_id=tenant_id, + name=name, + data_source_type="upload_file", + created_by=created_by, + ) + dataset.id = dataset_id + dataset.built_in_field_enabled = built_in_field_enabled + + for key, value in kwargs.items(): + setattr(dataset, key, value) + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_user_mock( + user_id: str | None = None, + tenant_id: str | None = None, + **kwargs, + ) -> Account: + """ + Create a mock user (Account) with specified attributes. + + Args: + user_id: Unique identifier for the user + tenant_id: Tenant identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock object configured as an Account instance + """ + user = create_autospec(Account, instance=True) + user.id = user_id or str(uuid4()) + user.current_tenant_id = tenant_id or str(uuid4()) + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + @staticmethod + def create_upload_file( + db_session_with_containers, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "test_file.pdf", + **kwargs, + ) -> UploadFile: + """ + Create a persisted UploadFile with specified attributes. 
+ + Args: + file_id: Unique identifier for the file + name: File name + **kwargs: Additional attributes to set on the entity + + Returns: + Persisted UploadFile instance + """ + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="pdf", + mime_type="application/pdf", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_TIME, + used=False, + ) + upload_file.id = file_id or str(uuid4()) + for key, value in kwargs.items(): + setattr(upload_file, key, value) + + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +class TestDocumentServicePauseDocument: + """ + Comprehensive integration tests for DocumentService.pause_document method. + + This test class covers the document pause functionality, which allows + users to pause the indexing process for documents that are currently + being indexed. + + The pause_document method: + 1. Validates document is in a pausable state + 2. Sets is_paused flag to True + 3. Records paused_by and paused_at + 4. Commits changes to database + 5. Sets pause flag in Redis cache + + Test scenarios include: + - Pausing documents in various indexing states + - Error handling for invalid states + - Redis cache flag setting + - Current user validation + """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """ + Mock document service dependencies for testing. 
+ + Provides mocked dependencies including: + - current_user context + - Database session + - Redis client + - Current time utilities + """ + with ( + patch( + "services.dataset_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user, + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, + ): + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + user_id = str(uuid4()) + mock_naive_utc_now.return_value = current_time + mock_current_user.id = user_id + + yield { + "current_user": mock_current_user, + "redis_client": mock_redis, + "naive_utc_now": mock_naive_utc_now, + "current_time": current_time, + "user_id": user_id, + } + + def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test successful pause of document in waiting state. + + Verifies that when a document is in waiting state, it can be + paused successfully. 
+ + This test ensures: + - Document state is validated + - is_paused flag is set + - paused_by and paused_at are recorded + - Changes are committed + - Redis cache flag is set + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="waiting", + is_paused=False, + ) + + # Act + DocumentService.pause_document(document) + + # Assert + db_session_with_containers.refresh(document) + assert document.is_paused is True + assert document.paused_by == mock_document_service_dependencies["user_id"] + assert document.paused_at == mock_document_service_dependencies["current_time"] + + expected_cache_key = f"document_{document.id}_is_paused" + mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") + + def test_pause_document_indexing_state_success( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test successful pause of document in indexing state. + + Verifies that when a document is actively being indexed, it can + be paused successfully. 
+ + This test ensures: + - Document in indexing state can be paused + - All pause operations complete correctly + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="indexing", + is_paused=False, + ) + + # Act + DocumentService.pause_document(document) + + # Assert + db_session_with_containers.refresh(document) + assert document.is_paused is True + assert document.paused_by == mock_document_service_dependencies["user_id"] + + def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test successful pause of document in parsing state. + + Verifies that when a document is being parsed, it can be paused. + + This test ensures: + - Document in parsing state can be paused + - Pause operations work for all valid states + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="parsing", + is_paused=False, + ) + + # Act + DocumentService.pause_document(document) + + # Assert + db_session_with_containers.refresh(document) + assert document.is_paused is True + + def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test error when trying to pause completed document. + + Verifies that when a document is already completed, it cannot + be paused and a DocumentIndexingError is raised. 
+ + This test ensures: + - Completed documents cannot be paused + - Error type is correct + - No database operations are performed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="completed", + is_paused=False, + ) + + # Act & Assert + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(document) + + db_session_with_containers.refresh(document) + assert document.is_paused is False + + def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test error when trying to pause document in error state. + + Verifies that when a document is in error state, it cannot be + paused and a DocumentIndexingError is raised. + + This test ensures: + - Error state documents cannot be paused + - Error type is correct + - No database operations are performed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="error", + is_paused=False, + ) + + # Act & Assert + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(document) + + db_session_with_containers.refresh(document) + assert document.is_paused is False + + +class TestDocumentServiceRecoverDocument: + """ + Comprehensive integration tests for DocumentService.recover_document method. + + This test class covers the document recovery functionality, which allows + users to resume indexing for documents that were previously paused. + + The recover_document method: + 1. Validates document is paused + 2. Clears is_paused flag + 3. Clears paused_by and paused_at + 4. Commits changes to database + 5. 
Deletes pause flag from Redis cache + 6. Triggers recovery task + + Test scenarios include: + - Recovering paused documents + - Error handling for non-paused documents + - Redis cache flag deletion + - Recovery task triggering + """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """ + Mock document service dependencies for testing. + + Provides mocked dependencies including: + - Database session + - Redis client + - Recovery task + """ + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.recover_document_indexing_task") as mock_task, + ): + yield { + "redis_client": mock_redis, + "recover_task": mock_task, + } + + def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test successful recovery of paused document. + + Verifies that when a document is paused, it can be recovered + successfully and indexing resumes. + + This test ensures: + - Document is validated as paused + - is_paused flag is cleared + - paused_by and paused_at are cleared + - Changes are committed + - Redis cache flag is deleted + - Recovery task is triggered + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + paused_time = FIXED_TIME + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="indexing", + is_paused=True, + paused_by=str(uuid4()), + paused_at=paused_time, + ) + + # Act + DocumentService.recover_document(document) + + # Assert + db_session_with_containers.refresh(document) + assert document.is_paused is False + assert document.paused_by is None + assert document.paused_at is None + + expected_cache_key = f"document_{document.id}_is_paused" + mock_document_service_dependencies["redis_client"].delete.assert_called_once_with(expected_cache_key) + 
mock_document_service_dependencies["recover_task"].delay.assert_called_once_with( + document.dataset_id, document.id + ) + + def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test error when trying to recover non-paused document. + + Verifies that when a document is not paused, it cannot be + recovered and a DocumentIndexingError is raised. + + This test ensures: + - Non-paused documents cannot be recovered + - Error type is correct + - No database operations are performed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + indexing_status="indexing", + is_paused=False, + ) + + # Act & Assert + with pytest.raises(DocumentIndexingError): + DocumentService.recover_document(document) + + db_session_with_containers.refresh(document) + assert document.is_paused is False + + +class TestDocumentServiceRetryDocument: + """ + Comprehensive integration tests for DocumentService.retry_document method. + + This test class covers the document retry functionality, which allows + users to retry failed document indexing operations. + + The retry_document method: + 1. Validates documents are not already being retried + 2. Sets retry flag in Redis cache + 3. Resets document indexing_status to waiting + 4. Commits changes to database + 5. Triggers retry task + + Test scenarios include: + - Retrying single document + - Retrying multiple documents + - Error handling for concurrent retries + - Current user validation + - Retry task triggering + """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """ + Mock document service dependencies for testing. 
+ + Provides mocked dependencies including: + - current_user context + - Database session + - Redis client + - Retry task + """ + with ( + patch( + "services.dataset_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user, + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.retry_document_indexing_task") as mock_task, + ): + user_id = str(uuid4()) + mock_current_user.id = user_id + + yield { + "current_user": mock_current_user, + "redis_client": mock_redis, + "retry_task": mock_task, + "user_id": user_id, + } + + def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test successful retry of single document. + + Verifies that when a document is retried, the retry process + completes successfully. + + This test ensures: + - Retry flag is checked + - Document status is reset to waiting + - Changes are committed + - Retry flag is set in Redis + - Retry task is triggered + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="error", + ) + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.retry_document(dataset.id, [document]) + + # Assert + db_session_with_containers.refresh(document) + assert document.indexing_status == "waiting" + + expected_cache_key = f"document_{document.id}_is_retried" + mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1) + mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( + dataset.id, [document.id], mock_document_service_dependencies["user_id"] + ) + + def test_retry_document_multiple_success(self, db_session_with_containers, 
mock_document_service_dependencies): + """ + Test successful retry of multiple documents. + + Verifies that when multiple documents are retried, all retry + processes complete successfully. + + This test ensures: + - Multiple documents can be retried + - All documents are processed + - Retry task is triggered with all document IDs + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document1 = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="error", + ) + document2 = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="error", + position=2, + ) + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.retry_document(dataset.id, [document1, document2]) + + # Assert + db_session_with_containers.refresh(document1) + db_session_with_containers.refresh(document2) + assert document1.indexing_status == "waiting" + assert document2.indexing_status == "waiting" + + mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( + dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"] + ) + + def test_retry_document_concurrent_retry_error( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test error when document is already being retried. + + Verifies that when a document is already being retried, a new + retry attempt raises a ValueError. 
+ + This test ensures: + - Concurrent retries are prevented + - Error message is clear + - Error type is correct + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="error", + ) + + mock_document_service_dependencies["redis_client"].get.return_value = "1" + + # Act & Assert + with pytest.raises(ValueError, match="Document is being retried, please try again later"): + DocumentService.retry_document(dataset.id, [document]) + + db_session_with_containers.refresh(document) + assert document.indexing_status == "error" + + def test_retry_document_missing_current_user_error( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test error when current_user is missing. + + Verifies that when current_user is None or has no ID, a ValueError + is raised. + + This test ensures: + - Current user validation works correctly + - Error message is clear + - Error type is correct + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="error", + ) + + mock_document_service_dependencies["redis_client"].get.return_value = None + mock_document_service_dependencies["current_user"].id = None + + # Act & Assert + with pytest.raises(ValueError, match="Current user or current user id not found"): + DocumentService.retry_document(dataset.id, [document]) + + +class TestDocumentServiceBatchUpdateDocumentStatus: + """ + Comprehensive integration tests for DocumentService.batch_update_document_status method. 
+ + This test class covers the batch document status update functionality, + which allows users to update the status of multiple documents at once. + + The batch_update_document_status method: + 1. Validates action parameter + 2. Validates all documents + 3. Checks if documents are being indexed + 4. Prepares updates for each document + 5. Applies all updates in a single transaction + 6. Triggers async tasks + 7. Sets Redis cache flags + + Test scenarios include: + - Batch enabling documents + - Batch disabling documents + - Batch archiving documents + - Batch unarchiving documents + - Handling empty lists + - Document indexing check + - Transaction rollback on errors + """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """ + Mock document service dependencies for testing. + + Provides mocked dependencies including: + - get_document method + - Database session + - Redis client + - Async tasks + """ + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.add_document_to_index_task") as mock_add_task, + patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, + patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, + ): + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_naive_utc_now.return_value = current_time + + yield { + "redis_client": mock_redis, + "add_task": mock_add_task, + "remove_task": mock_remove_task, + "naive_utc_now": mock_naive_utc_now, + "current_time": current_time, + } + + def test_batch_update_document_status_enable_success( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test successful batch enabling of documents. + + Verifies that when documents are enabled in batch, all operations + complete successfully. 
+ + This test ensures: + - Documents are retrieved correctly + - Enabled flag is set + - Async tasks are triggered + - Redis cache flags are set + - Transaction is committed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document1 = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + enabled=False, + indexing_status="completed", + ) + document2 = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + enabled=False, + indexing_status="completed", + position=2, + ) + document_ids = [document1.id, document2.id] + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + + # Assert + db_session_with_containers.refresh(document1) + db_session_with_containers.refresh(document2) + assert document1.enabled is True + assert document2.enabled is True + assert mock_document_service_dependencies["add_task"].delay.call_count == 2 + + def test_batch_update_document_status_disable_success( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test successful batch disabling of documents. + + Verifies that when documents are disabled in batch, all operations + complete successfully. 
+ + This test ensures: + - Documents are retrieved correctly + - Enabled flag is cleared + - Disabled_at and disabled_by are set + - Async tasks are triggered + - Transaction is committed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + enabled=True, + indexing_status="completed", + completed_at=FIXED_TIME, + ) + document_ids = [document.id] + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.batch_update_document_status(dataset, document_ids, "disable", user) + + # Assert + db_session_with_containers.refresh(document) + assert document.enabled is False + assert document.disabled_at == mock_document_service_dependencies["current_time"] + assert document.disabled_by == user.id + mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_document_status_archive_success( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test successful batch archiving of documents. + + Verifies that when documents are archived in batch, all operations + complete successfully. 
+ + This test ensures: + - Documents are retrieved correctly + - Archived flag is set + - Archived_at and archived_by are set + - Async tasks are triggered for enabled documents + - Transaction is committed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + archived=False, + enabled=True, + indexing_status="completed", + ) + document_ids = [document.id] + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.batch_update_document_status(dataset, document_ids, "archive", user) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is True + assert document.archived_at == mock_document_service_dependencies["current_time"] + assert document.archived_by == user.id + mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_document_status_unarchive_success( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test successful batch unarchiving of documents. + + Verifies that when documents are unarchived in batch, all operations + complete successfully. 
+ + This test ensures: + - Documents are retrieved correctly + - Archived flag is cleared + - Archived_at and archived_by are cleared + - Async tasks are triggered for enabled documents + - Transaction is committed + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + archived=True, + enabled=True, + indexing_status="completed", + ) + document_ids = [document.id] + + mock_document_service_dependencies["redis_client"].get.return_value = None + + # Act + DocumentService.batch_update_document_status(dataset, document_ids, "un_archive", user) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is False + assert document.archived_at is None + assert document.archived_by is None + mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_document_status_empty_list( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test handling of empty document list. + + Verifies that when an empty list is provided, the method returns + early without performing any operations. 
+ + This test ensures: + - Empty lists are handled gracefully + - No database operations are performed + - No errors are raised + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document_ids = [] + + # Act + DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + + # Assert + mock_document_service_dependencies["add_task"].delay.assert_not_called() + mock_document_service_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_document_status_document_indexing_error( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test error when document is being indexed. + + Verifies that when a document is currently being indexed, a + DocumentIndexingError is raised. + + This test ensures: + - Indexing documents cannot be updated + - Error message is clear + - Error type is correct + """ + # Arrange + dataset = DocumentStatusTestDataFactory.create_dataset(db_session_with_containers) + user = DocumentStatusTestDataFactory.create_user_mock(tenant_id=dataset.tenant_id) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + document_id=str(uuid4()), + indexing_status="completed", + ) + document_ids = [document.id] + + mock_document_service_dependencies["redis_client"].get.return_value = "1" + + # Act & Assert + with pytest.raises(DocumentIndexingError, match="is being indexed"): + DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + + +class TestDocumentServiceRenameDocument: + """ + Comprehensive integration tests for DocumentService.rename_document method. + + This test class covers the document renaming functionality, which allows + users to rename documents for better organization. + + The rename_document method: + 1. 
Validates dataset exists + 2. Validates document exists + 3. Validates tenant permission + 4. Updates document name + 5. Updates metadata if built-in fields enabled + 6. Updates associated upload file name + 7. Commits changes + + Test scenarios include: + - Successful document renaming + - Dataset not found error + - Document not found error + - Permission validation + - Metadata updates + - Upload file name updates + """ + + @pytest.fixture + def mock_document_service_dependencies(self): + """ + Mock document service dependencies for testing. + + Provides mocked dependencies including: + - DatasetService.get_dataset + - DocumentService.get_document + - current_user context + - Database session + """ + with patch( + "services.dataset_service.current_user", create_autospec(Account, instance=True) + ) as mock_current_user: + mock_current_user.current_tenant_id = str(uuid4()) + + yield { + "current_user": mock_current_user, + } + + def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test successful document renaming. + + Verifies that when all validation passes, a document is renamed + successfully. 
+ + This test ensures: + - Dataset is retrieved correctly + - Document is retrieved correctly + - Document name is updated + - Changes are committed + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + tenant_id = mock_document_service_dependencies["current_user"].current_tenant_id + + dataset = DocumentStatusTestDataFactory.create_dataset( + db_session_with_containers, dataset_id=dataset_id, tenant_id=tenant_id + ) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=tenant_id, + indexing_status="completed", + ) + + # Act + result = DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert result == document + assert document.name == new_name + + def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test document renaming with built-in fields enabled. + + Verifies that when built-in fields are enabled, the document + metadata is also updated. 
+ + This test ensures: + - Document name is updated + - Metadata is updated with new name + - Built-in field is set correctly + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + tenant_id = mock_document_service_dependencies["current_user"].current_tenant_id + + dataset = DocumentStatusTestDataFactory.create_dataset( + db_session_with_containers, + dataset_id=dataset_id, + tenant_id=tenant_id, + built_in_field_enabled=True, + ) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + indexing_status="completed", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert document.name == new_name + assert "document_name" in document.doc_metadata + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["existing_key"] == "existing_value" + + def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test document renaming with associated upload file. + + Verifies that when a document has an associated upload file, + the file name is also updated. 
+ + This test ensures: + - Document name is updated + - Upload file name is updated + - Database query is executed correctly + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + file_id = str(uuid4()) + tenant_id = mock_document_service_dependencies["current_user"].current_tenant_id + + dataset = DocumentStatusTestDataFactory.create_dataset( + db_session_with_containers, dataset_id=dataset_id, tenant_id=tenant_id + ) + upload_file = DocumentStatusTestDataFactory.create_upload_file( + db_session_with_containers, + tenant_id=tenant_id, + created_by=str(uuid4()), + file_id=file_id, + name="old_name.pdf", + ) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=tenant_id, + data_source_info={"upload_file_id": upload_file.id}, + indexing_status="completed", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(upload_file) + assert document.name == new_name + assert upload_file.name == new_name + + def test_rename_document_dataset_not_found_error( + self, db_session_with_containers, mock_document_service_dependencies + ): + """ + Test error when dataset is not found. + + Verifies that when the dataset ID doesn't exist, a ValueError + is raised. + + This test ensures: + - Dataset existence is validated + - Error message is clear + - Error type is correct + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + + # Act & Assert + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document(dataset_id, document_id, new_name) + + def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test error when document is not found. 
+ + Verifies that when the document ID doesn't exist, a ValueError + is raised. + + This test ensures: + - Document existence is validated + - Error message is clear + - Error type is correct + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + + dataset = DocumentStatusTestDataFactory.create_dataset( + db_session_with_containers, + dataset_id=dataset_id, + tenant_id=mock_document_service_dependencies["current_user"].current_tenant_id, + ) + + # Act & Assert + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, document_id, new_name) + + def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies): + """ + Test error when user lacks permission. + + Verifies that when the user is in a different tenant, a ValueError + is raised. + + This test ensures: + - Tenant permission is validated + - Error message is clear + - Error type is correct + """ + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + current_tenant_id = mock_document_service_dependencies["current_user"].current_tenant_id + + dataset = DocumentStatusTestDataFactory.create_dataset( + db_session_with_containers, + dataset_id=dataset_id, + tenant_id=current_tenant_id, + ) + document = DocumentStatusTestDataFactory.create_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=str(uuid4()), + indexing_status="completed", + ) + + # Act & Assert + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, new_name) diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index fb6304a59e..e7cc140582 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -19,14 +19,14 @@ class TestAgentService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, - patch("services.agent_service.ToolManager") as mock_tool_manager, - patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, + patch("services.agent_service.PluginAgentClient", autospec=True) as mock_plugin_agent_client, + patch("services.agent_service.ToolManager", autospec=True) as mock_tool_manager, + patch("services.agent_service.AgentConfigManager", autospec=True) as mock_agent_config_manager, patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, - patch("services.app_service.FeatureService") as mock_feature_service, - patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, - patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, + patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, + patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service mock_plugin_agent_client_instance = mock_plugin_agent_client.return_value diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 81bfa0ea20..8544d23cdf 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -18,18 +18,22 @@ class TestAppGenerateService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.billing_service.BillingService") as mock_billing_service, - patch("services.app_generate_service.WorkflowService") as mock_workflow_service, - patch("services.app_generate_service.RateLimit") as mock_rate_limit, - patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator, - patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator, - patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, - patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, - patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, - patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator, - patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_generate_service.dify_config") as mock_dify_config, - patch("configs.dify_config") as mock_global_dify_config, + patch("services.billing_service.BillingService", autospec=True) as mock_billing_service, + patch("services.app_generate_service.WorkflowService", autospec=True) as mock_workflow_service, + patch("services.app_generate_service.RateLimit", autospec=True) as mock_rate_limit, + patch("services.app_generate_service.CompletionAppGenerator", autospec=True) as mock_completion_generator, + patch("services.app_generate_service.ChatAppGenerator", autospec=True) as mock_chat_generator, + patch("services.app_generate_service.AgentChatAppGenerator", autospec=True) as mock_agent_chat_generator, + patch( + "services.app_generate_service.AdvancedChatAppGenerator", autospec=True + ) as mock_advanced_chat_generator, + 
patch("services.app_generate_service.WorkflowAppGenerator", autospec=True) as mock_workflow_generator, + patch( + "services.app_generate_service.MessageBasedAppGenerator", autospec=True + ) as mock_message_based_generator, + patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, + patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service mock_billing_service.update_tenant_feature_plan_usage.return_value = { @@ -983,7 +987,7 @@ class TestAppGenerateService: } # Execute the method under test - with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params: + with patch("services.app_generate_service.AppExecutionParams", autospec=True) as mock_exec_params: mock_payload = MagicMock() mock_payload.workflow_run_id = fake.uuid4() mock_payload.model_dump_json.return_value = "{}" diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index ba8e89feb1..5f64e6f674 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -1034,3 +1034,34 @@ class TestConversationServiceExport: # Step 2: Async cleanup task triggered # The Celery task will handle cleanup of messages, annotations, etc. mock_delete_task.delay.assert_called_once_with(conversation_id) + + @patch("services.conversation_service.delete_conversation_related_data") + def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers): + """ + Test deletion is denied when conversation belongs to a different account. 
+ """ + # Arrange + app_model, owner_account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + _, other_account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, + app_model, + owner_account, + ) + + # Act & Assert + with pytest.raises(ConversationNotExistsError): + ConversationService.delete( + app_model=app_model, + conversation_id=conversation.id, + user=other_account, + ) + + # Verify no deletion and no async cleanup trigger + not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id)) + assert not_deleted is not None + mock_delete_task.delay.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py new file mode 100644 index 0000000000..f05c47913e --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -0,0 +1,418 @@ +"""Integration tests for SQL-oriented DatasetService scenarios. + +This suite migrates SQL-backed behaviors from the old unit suite to real +container-backed integration tests. The tests exercise real ORM persistence and +only patch non-DB collaborators when needed. 
+""" + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel +from services.errors.dataset import DatasetNameDuplicateError + + +class DatasetServiceIntegrationDataFactory: + """Factory for creating real database entities used by integration tests.""" + + @staticmethod + def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + """Create an account and tenant, then bind the account as current tenant member.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db.session.add_all([account, tenant]) + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.flush() + + # Keep tenant context on the in-memory user without opening a separate session. 
+ account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + name: str = "Test Dataset", + description: str | None = "Test description", + provider: str = "vendor", + indexing_technique: str | None = "high_quality", + permission: str = DatasetPermissionEnum.ONLY_ME, + retrieval_model: dict | None = None, + embedding_model_provider: str | None = None, + embedding_model: str | None = None, + collection_binding_id: str | None = None, + chunk_structure: str | None = None, + ) -> Dataset: + """Create a dataset record with configurable SQL fields.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description=description, + data_source_type="upload_file", + indexing_technique=indexing_technique, + created_by=created_by, + provider=provider, + permission=permission, + retrieval_model=retrieval_model, + embedding_model_provider=embedding_model_provider, + embedding_model=embedding_model, + collection_binding_id=collection_binding_id, + chunk_structure=chunk_structure, + ) + db.session.add(dataset) + db.session.flush() + return dataset + + @staticmethod + def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: + """Create a document row belonging to the given dataset.""" + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + data_source_info='{"upload_file_id": "upload-file-id"}', + batch=str(uuid4()), + name=name, + created_from="web", + created_by=created_by, + indexing_status="completed", + doc_form="text_model", + ) + db.session.add(document) + db.session.flush() + return document + + @staticmethod + def create_embedding_model(provider: str = "openai", model_name: str = "text-embedding-ada-002") -> Mock: + """Create a fake embedding model object for external provider boundary patching.""" + embedding_model = Mock() + embedding_model.provider = provider + 
embedding_model.model_name = model_name + return embedding_model + + +class TestDatasetServiceCreateDataset: + """Integration coverage for DatasetService.create_empty_dataset.""" + + def test_create_internal_dataset_basic_success(self, db_session_with_containers): + """Create a basic internal dataset with minimal configuration.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Basic Internal Dataset", + description="Test description", + indexing_technique=None, + account=account, + ) + + # Assert + created_dataset = db.session.get(Dataset, result.id) + assert created_dataset is not None + assert created_dataset.provider == "vendor" + assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME + assert created_dataset.embedding_model_provider is None + assert created_dataset.embedding_model is None + + def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): + """Create an internal dataset with economy indexing and no embedding model.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Economy Dataset", + description=None, + indexing_technique="economy", + account=account, + ) + + # Assert + db.session.refresh(result) + assert result.indexing_technique == "economy" + assert result.embedding_model_provider is None + assert result.embedding_model is None + + def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): + """Create a high-quality dataset and persist embedding model settings.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() + + # Act + with 
patch("services.dataset_service.ModelManager") as mock_model_manager: + mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model + + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="High Quality Dataset", + description=None, + indexing_technique="high_quality", + account=account, + ) + + # Assert + db.session.refresh(result) + assert result.indexing_technique == "high_quality" + assert result.embedding_model_provider == embedding_model.provider + assert result.embedding_model == embedding_model.model_name + mock_model_manager.return_value.get_default_model_instance.assert_called_once_with( + tenant_id=tenant.id, + model_type=ModelType.TEXT_EMBEDDING, + ) + + def test_create_dataset_duplicate_name_error(self, db_session_with_containers): + """Raise duplicate-name error when the same tenant already has the name.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name="Duplicate Dataset", + indexing_technique=None, + ) + + # Act / Assert + with pytest.raises(DatasetNameDuplicateError): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Duplicate Dataset", + description=None, + indexing_technique=None, + account=account, + ) + + def test_create_external_dataset_success(self, db_session_with_containers): + """Create an external dataset and persist external knowledge binding.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + external_knowledge_api_id = str(uuid4()) + external_knowledge_id = "knowledge-123" + + # Act + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = Mock(id=external_knowledge_api_id) + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Dataset", + 
description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=external_knowledge_id, + ) + + # Assert + binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() + assert result.provider == "external" + assert binding is not None + assert binding.external_knowledge_id == external_knowledge_id + assert binding.external_knowledge_api_id == external_knowledge_api_id + + def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): + """Create a high-quality dataset with retrieval/reranking settings.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="cohere", + reranking_model_name="rerank-english-v2.0", + ), + top_k=3, + score_threshold_enabled=True, + score_threshold=0.6, + ) + + # Act + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, + ): + mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model + + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Dataset With Reranking", + description=None, + indexing_technique="high_quality", + account=account, + retrieval_model=retrieval_model, + ) + + # Assert + db.session.refresh(result) + assert result.retrieval_model == retrieval_model.model_dump() + mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") + + +class TestDatasetServiceUpdateAndDeleteDataset: + """Integration coverage for SQL-backed update and delete behavior.""" + + 
def test_update_dataset_duplicate_name_error(self, db_session_with_containers): + """Reject update when target name already exists within the same tenant.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name="Source Dataset", + ) + DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + name="Existing Dataset", + ) + + # Act / Assert + with pytest.raises(ValueError, match="Dataset name already exists"): + DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) + + def test_delete_dataset_with_documents_success(self, db_session_with_containers): + """Delete a dataset that already has documents.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + indexing_technique="high_quality", + chunk_structure="text_model", + ) + DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) + + # Act + with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, account) + + # Assert + assert result is True + assert db.session.get(Dataset, dataset.id) is None + dataset_deleted_signal.send.assert_called_once_with(dataset) + + def test_delete_empty_dataset_success(self, db_session_with_containers): + """Delete a dataset that has no documents and no indexing technique.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + indexing_technique=None, + chunk_structure=None, + ) + + # Act + with 
patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, account) + + # Assert + assert result is True + assert db.session.get(Dataset, dataset.id) is None + dataset_deleted_signal.send.assert_called_once_with(dataset) + + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + """Delete dataset when indexing_technique is None but doc_form path still exists.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + indexing_technique=None, + chunk_structure="text_model", + ) + + # Act + with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, account) + + # Assert + assert result is True + assert db.session.get(Dataset, dataset.id) is None + dataset_deleted_signal.send.assert_called_once_with(dataset) + + +class TestDatasetServiceRetrievalConfiguration: + """Integration coverage for retrieval configuration persistence.""" + + def test_get_dataset_retrieval_configuration(self, db_session_with_containers): + """Return retrieval configuration that is persisted in SQL.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + retrieval_model = { + "search_method": "semantic_search", + "top_k": 5, + "score_threshold": 0.5, + "reranking_enable": True, + } + dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + retrieval_model=retrieval_model, + ) + + # Act + result = DatasetService.get_dataset(dataset.id) + + # Assert + assert result is not None + assert result.retrieval_model == retrieval_model + assert result.retrieval_model["search_method"] == "semantic_search" + assert result.retrieval_model["top_k"] == 5 + + def 
test_update_dataset_retrieval_configuration(self, db_session_with_containers): + """Persist retrieval configuration updates through DatasetService.update_dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + dataset = DatasetServiceIntegrationDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=account.id, + indexing_technique="high_quality", + retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0}, + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=str(uuid4()), + ) + update_data = { + "indexing_technique": "high_quality", + "retrieval_model": { + "search_method": "full_text_search", + "top_k": 10, + "score_threshold": 0.7, + }, + } + + # Act + result = DatasetService.update_dataset(dataset.id, update_data, account) + + # Assert + db.session.refresh(dataset) + assert result.id == dataset.id + assert dataset.retrieval_model == update_data["retrieval_model"] diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 8a72331425..7c8472e819 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -17,10 +17,12 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager, - patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory, - patch("services.model_load_balancing_service.encrypter") as mock_encrypter, + 
patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, + patch( + "services.model_load_balancing_service.ModelProviderFactory", autospec=True + ) as mock_model_provider_factory, + patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns mock_provider_manager_instance = mock_provider_manager.return_value diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index d57ab7428b..f7044f7d45 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -17,8 +17,8 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager") as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory") as mock_model_provider_factory, + patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, + patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() @@ -526,7 +526,9 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() - with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method: + with patch.object( + service, "get_provider_credential", return_value=expected_credentials, autospec=True + ) as mock_method: result = service.get_provider_credential(tenant.id, 
"openai") # Assert: Verify the expected outcomes @@ -854,7 +856,9 @@ class TestModelProviderService: # Act: Execute the method under test service = ModelProviderService() - with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method: + with patch.object( + service, "get_model_credential", return_value=expected_credentials, autospec=True + ) as mock_method: result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None) # Assert: Verify the expected outcomes diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 934d1bdd34..8f345b9cea 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -22,16 +22,13 @@ class TestWebhookService: def mock_external_dependencies(self): """Mock external service dependencies.""" with ( - patch("services.trigger.webhook_service.AsyncWorkflowService") as mock_async_service, - patch("services.trigger.webhook_service.ToolFileManager") as mock_tool_file_manager, - patch("services.trigger.webhook_service.file_factory") as mock_file_factory, - patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.trigger.webhook_service.AsyncWorkflowService", autospec=True) as mock_async_service, + patch("services.trigger.webhook_service.ToolFileManager", autospec=True) as mock_tool_file_manager, + patch("services.trigger.webhook_service.file_factory", autospec=True) as mock_file_factory, + patch("services.account_service.FeatureService", autospec=True) as mock_feature_service, ): # Mock ToolFileManager - mock_tool_file_instance = MagicMock() - mock_tool_file_manager.return_value = mock_tool_file_instance - - # Mock file creation + mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation mock_tool_file 
= MagicMock() mock_tool_file.id = "test_file_id" mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file @@ -435,12 +432,12 @@ class TestWebhookService: with flask_app_with_containers.app_context(): # Mock tenant owner lookup to return the test account - with patch("services.trigger.webhook_service.select") as mock_select: + with patch("services.trigger.webhook_service.select", autospec=True) as mock_select: mock_query = MagicMock() mock_select.return_value.join.return_value.where.return_value = mock_query # Mock the session to return our test account - with patch("services.trigger.webhook_service.Session") as mock_session: + with patch("services.trigger.webhook_service.Session", autospec=True) as mock_session: mock_session_instance = MagicMock() mock_session.return_value.__enter__.return_value = mock_session_instance mock_session_instance.scalar.return_value = test_data["account"] @@ -462,7 +459,7 @@ class TestWebhookService: with flask_app_with_containers.app_context(): # Mock EndUserService to raise an exception with patch( - "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type" + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", autospec=True ) as mock_end_user: mock_end_user.side_effect = ValueError("Failed to create end user") diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index ee155021e3..1f91b40963 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,8 +1,8 @@ import pytest from faker import Faker -from core.variables.segments import StringSegment from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variables.segments 
import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -467,7 +467,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.variables.variables import StringVariable + from core.workflow.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -650,7 +650,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.variables.variables import StringVariable + from core.workflow.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index cb691d5c3d..c29cda9a73 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -764,7 +764,7 @@ class TestWorkflowService: # Act - Mock current_user context and pass session from unittest.mock import patch - with patch("flask_login.utils._get_user", return_value=account): + with patch("flask_login.utils._get_user", return_value=account, autospec=True): result = workflow_service.publish_workflow( session=db_session_with_containers, app_model=app, account=account ) @@ -1391,10 +1391,21 @@ class TestWorkflowService: workflow_service = WorkflowService() + from unittest.mock import patch + + from core.app.workflow.node_factory import DifyNodeFactory + from core.model_manager import ModelInstance + # Act - 
result = workflow_service.run_free_workflow_node( - node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs - ) + with patch.object( + DifyNodeFactory, + "_build_model_instance_for_llm_node", + return_value=MagicMock(spec=ModelInstance), + autospec=True, + ): + result = workflow_service.run_free_workflow_node( + node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs + ) # Assert assert result is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 088d6ba6ba..8bb536c34a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -18,7 +18,9 @@ class TestAddDocumentToIndexTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.add_document_to_index_task.IndexProcessorFactory") as mock_index_processor_factory, + patch( + "tasks.add_document_to_index_task.IndexProcessorFactory", autospec=True + ) as mock_index_processor_factory, ): # Setup mock index processor mock_processor = MagicMock() @@ -378,7 +380,7 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Mock the get_child_chunks method for each segment - with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks: + with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks: # Setup mock to return child chunks for each segment mock_child_chunks = [] for i in range(2): # Each segment has 2 child chunks diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py 
b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 61f6b75b10..2156743c17 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -51,9 +51,9 @@ class TestBatchCreateSegmentToIndexTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.batch_create_segment_to_index_task.storage") as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager") as mock_model_manager, - patch("tasks.batch_create_segment_to_index_task.VectorService") as mock_vector_service, + patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, + patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns mock_storage.download.return_value = None diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 09407f7686..cd99b2965f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -63,8 +63,8 @@ class TestCleanDatasetTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.clean_dataset_task.storage") as mock_storage, - patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_index_processor_factory, + patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage, + patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_index_processor_factory, ): # Setup default 
mock returns mock_storage.delete.return_value = None @@ -597,7 +597,7 @@ class TestCleanDatasetTask: db_session_with_containers.commit() # Mock the get_image_upload_file_ids function to return our image file IDs - with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids: + with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_get_image_ids: mock_get_image_ids.return_value = [f.id for f in image_files] # Execute the task diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index caa5ee3851..4fa52ff2a9 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -41,7 +41,7 @@ class TestCreateSegmentToIndexTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory, + patch("tasks.create_segment_to_index_task.IndexProcessorFactory", autospec=True) as mock_factory, ): # Setup default mock returns mock_processor = MagicMock() @@ -708,7 +708,7 @@ class TestCreateSegmentToIndexTask: redis_client.set(cache_key, "processing", ex=300) # Mock Redis to raise exception in finally block - with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")): + with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed"), autospec=True): # Act: Execute the task - Redis failure should not prevent completion with pytest.raises(Exception) as exc_info: create_segment_to_index_task(segment.id) diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 
c3ad18ecec..207bdad751 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -37,7 +37,7 @@ class _TrackedSessionContext: self._closed_sessions.append(self._session) return original_close(*args, **kwargs) - self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close) + self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close, autospec=True) self._close_patcher.start() return self._session @@ -69,7 +69,9 @@ def session_close_tracker(): original_context_manager = original_create_session(*args, **kwargs) return _TrackedSessionContext(original_context_manager, opened_sessions, closed_sessions) - with patch.object(task_module.session_factory, "create_session", side_effect=_tracked_create_session): + with patch.object( + task_module.session_factory, "create_session", side_effect=_tracked_create_session, autospec=True + ): yield {"opened_sessions": opened_sessions, "closed_sessions": closed_sessions} @@ -77,13 +79,11 @@ def session_close_tracker(): def patched_external_dependencies(): """Patch non-DB collaborators while keeping database behavior real.""" with ( - patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner, - patch("tasks.document_indexing_task.FeatureService") as mock_feature_service, - patch("tasks.document_indexing_task.generate_summary_index_task") as mock_summary_task, + patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner, + patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service, + patch("tasks.document_indexing_task.generate_summary_index_task", autospec=True) as mock_summary_task, ): - mock_runner_instance = MagicMock() - mock_indexing_runner.return_value = mock_runner_instance - + mock_runner_instance = mock_indexing_runner.return_value mock_features = MagicMock() 
mock_features.billing.enabled = False mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL @@ -307,9 +307,17 @@ class TestDatasetIndexingTaskIntegration: # Act with ( - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy, - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy, + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", + return_value=[next_task], + autospec=True, + ), + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True + ) as set_waiting_spy, + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True + ) as delete_key_spy, ): _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) @@ -336,8 +344,10 @@ class TestDatasetIndexingTaskIntegration: # Act with ( - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy, + patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True), + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True + ) as delete_key_spy, ): _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) @@ -426,9 +436,13 @@ class TestDatasetIndexingTaskIntegration: # Act with ( - patch("tasks.document_indexing_task._document_indexing", side_effect=Exception("failed")), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"), + patch("tasks.document_indexing_task._document_indexing", 
side_effect=Exception("failed"), autospec=True), + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", + return_value=[next_task], + autospec=True, + ), + patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True), ): _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) @@ -511,8 +525,11 @@ class TestDatasetIndexingTaskIntegration: patch( "tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=pending_tasks[:concurrency_limit], + autospec=True, ), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy, + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True + ) as set_waiting_spy, ): _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) @@ -538,8 +555,12 @@ class TestDatasetIndexingTaskIntegration: # Act with ( patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=ordered_tasks), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"), + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", + return_value=ordered_tasks, + autospec=True, + ), + patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True), ): _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) @@ -578,8 +599,10 @@ class TestDatasetIndexingTaskIntegration: # Act with ( - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy, + patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True), + patch( 
+ "tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True + ) as delete_key_spy, ): normal_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) @@ -599,8 +622,10 @@ class TestDatasetIndexingTaskIntegration: # Act with ( - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]), - patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy, + patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True), + patch( + "tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True + ) as delete_key_spy, ): priority_document_indexing_task(dataset.tenant_id, dataset.id, document_ids) diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 37d886f569..bc0ed3bd2b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -216,7 +216,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return segments - @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): """ Test successful segment deletion from index with comprehensive verification. 
@@ -399,7 +399,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed - @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) def test_delete_segment_from_index_task_index_processor_clean( self, mock_index_processor_factory, db_session_with_containers ): @@ -457,7 +457,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.reset_mock() mock_processor.reset_mock() - @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) def test_delete_segment_from_index_task_exception_handling( self, mock_index_processor_factory, db_session_with_containers ): @@ -501,7 +501,7 @@ class TestDeleteSegmentFromIndexTask: assert call_args[1]["with_keywords"] is True assert call_args[1]["delete_child_chunks"] is True - @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) def test_delete_segment_from_index_task_empty_index_node_ids( self, mock_index_processor_factory, db_session_with_containers ): @@ -543,7 +543,7 @@ class TestDeleteSegmentFromIndexTask: assert call_args[1]["with_keywords"] is True assert call_args[1]["delete_child_chunks"] is True - @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory") + @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) def test_delete_segment_from_index_task_large_index_node_ids( self, mock_index_processor_factory, db_session_with_containers ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py new file mode 100644 index 0000000000..0b9e29fde9 --- 
/dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -0,0 +1,464 @@ +""" +Integration tests for document_indexing_sync_task using testcontainers. + +This module validates SQL-backed behavior for document sync flows: +- Notion sync precondition checks +- Segment cleanup and document state updates +- Credential and indexing error handling +""" + +import json +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from psycopg2.extensions import register_adapter +from psycopg2.extras import Json + +from core.indexing_runner import DocumentIsPausedError, IndexingRunner +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_sync_task import document_indexing_sync_task + + +@pytest.fixture(autouse=True) +def _register_dict_adapter_for_psycopg2(): + """Align test DB adapter behavior with dict payloads used in task update flow.""" + register_adapter(dict, Json) + + +class DocumentIndexingSyncTaskTestDataFactory: + """Create real DB entities for document indexing sync integration tests.""" + + @staticmethod + def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + tenant = Tenant(name=f"tenant-{account.id}", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + return account, tenant + + @staticmethod + def create_dataset(db_session_with_containers, tenant_id: str, created_by: str) -> Dataset: + 
dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + description="sync test dataset", + data_source_type="notion_import", + indexing_technique="high_quality", + created_by=created_by, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + tenant_id: str, + dataset_id: str, + created_by: str, + data_source_info: dict | None, + indexing_status: str = "completed", + ) -> Document: + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=0, + data_source_type="notion_import", + data_source_info=json.dumps(data_source_info) if data_source_info is not None else None, + batch="test-batch", + name=f"doc-{uuid4()}", + created_from="notion_import", + created_by=created_by, + indexing_status=indexing_status, + enabled=True, + doc_form="text_model", + doc_language="en", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_segments( + db_session_with_containers, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + created_by: str, + count: int = 3, + ) -> list[DocumentSegment]: + segments: list[DocumentSegment] = [] + for i in range(count): + segment = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + position=i, + content=f"segment-{i}", + answer=None, + word_count=10, + tokens=5, + index_node_id=f"node-{document_id}-{i}", + status="completed", + created_by=created_by, + ) + db_session_with_containers.add(segment) + segments.append(segment) + db_session_with_containers.commit() + return segments + + +class TestDocumentIndexingSyncTask: + """Integration tests for document_indexing_sync_task with real database assertions.""" + + @pytest.fixture + def mock_external_dependencies(self): + """Patch only external collaborators; keep DB access real.""" + with ( + 
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_datasource_service_class, + patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_notion_extractor_class, + patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_index_processor_factory, + patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_indexing_runner_class, + ): + datasource_service = Mock() + datasource_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"} + mock_datasource_service_class.return_value = datasource_service + + notion_extractor = Mock() + notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" + mock_notion_extractor_class.return_value = notion_extractor + + index_processor = Mock() + index_processor.clean = Mock() + mock_index_processor_factory.return_value.init_index_processor.return_value = index_processor + + indexing_runner = Mock(spec=IndexingRunner) + indexing_runner.run = Mock() + mock_indexing_runner_class.return_value = indexing_runner + + yield { + "datasource_service": datasource_service, + "notion_extractor": notion_extractor, + "notion_extractor_class": mock_notion_extractor_class, + "index_processor": index_processor, + "index_processor_factory": mock_index_processor_factory, + "indexing_runner": indexing_runner, + } + + def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None): + account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=account.id, + ) + + notion_info = data_source_info or { + "notion_workspace_id": str(uuid4()), + "notion_page_id": str(uuid4()), + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + "credential_id": str(uuid4()), + } + + document = 
DocumentIndexingSyncTaskTestDataFactory.create_document( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=account.id, + data_source_info=notion_info, + indexing_status="completed", + ) + + segments = DocumentIndexingSyncTaskTestDataFactory.create_segments( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + created_by=account.id, + count=3, + ) + + return { + "account": account, + "tenant": tenant, + "dataset": dataset, + "document": document, + "segments": segments, + "node_ids": [segment.index_node_id for segment in segments], + "notion_info": notion_info, + } + + def test_document_not_found(self, db_session_with_containers, mock_external_dependencies): + """Test that task handles missing document gracefully.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + + # Act + document_indexing_sync_task(dataset_id, document_id) + + # Assert + mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called() + mock_external_dependencies["indexing_runner"].run.assert_not_called() + + def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies): + """Test that task raises error when notion_workspace_id is missing.""" + # Arrange + context = self._create_notion_sync_context( + db_session_with_containers, + data_source_info={ + "notion_page_id": str(uuid4()), + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + }, + ) + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies): + """Test that task raises error when notion_page_id is missing.""" + # Arrange + context = self._create_notion_sync_context( + db_session_with_containers, + data_source_info={ + "notion_workspace_id": 
str(uuid4()), + "type": "page", + "last_edited_time": "2024-01-01T00:00:00Z", + }, + ) + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies): + """Test that task raises error when data_source_info is empty.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) + db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( + {"data_source_info": None} + ) + db_session_with_containers.commit() + + # Act & Assert + with pytest.raises(ValueError, match="no notion page found"): + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies): + """Test that task sets document error state when credential is missing.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + mock_external_dependencies["datasource_service"].get_datasource_credentials.return_value = None + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + assert updated_document is not None + assert updated_document.indexing_status == "error" + assert "Datasource credential not found" in updated_document.error + assert updated_document.stopped_at is not None + mock_external_dependencies["indexing_runner"].run.assert_not_called() + + def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies): + """Test that task exits early when notion page is unchanged.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) 
+ mock_external_dependencies["notion_extractor"].get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + remaining_segments = ( + db_session_with_containers.query(DocumentSegment) + .where(DocumentSegment.document_id == context["document"].id) + .count() + ) + assert updated_document is not None + assert updated_document.indexing_status == "completed" + assert updated_document.processing_started_at is None + assert remaining_segments == 3 + mock_external_dependencies["index_processor"].clean.assert_not_called() + mock_external_dependencies["indexing_runner"].run.assert_not_called() + + def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies): + """Test full successful sync flow with SQL state updates and side effects.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + remaining_segments = ( + db_session_with_containers.query(DocumentSegment) + .where(DocumentSegment.document_id == context["document"].id) + .count() + ) + + assert updated_document is not None + assert updated_document.indexing_status == "parsing" + assert updated_document.processing_started_at is not None + assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z" + assert remaining_segments == 0 + + clean_call_args = mock_external_dependencies["index_processor"].clean.call_args + assert clean_call_args is not None + clean_args, clean_kwargs 
= clean_call_args + assert getattr(clean_args[0], "id", None) == context["dataset"].id + assert set(clean_args[1]) == set(context["node_ids"]) + assert clean_kwargs.get("with_keywords") is True + assert clean_kwargs.get("delete_child_chunks") is True + + run_call_args = mock_external_dependencies["indexing_runner"].run.call_args + assert run_call_args is not None + run_documents = run_call_args[0][0] + assert len(run_documents) == 1 + assert getattr(run_documents[0], "id", None) == context["document"].id + + def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies): + """Test that task still updates document and reindexes if dataset vanishes before clean.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + + def _delete_dataset_before_clean() -> str: + db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.commit() + return "2024-01-02T00:00:00Z" + + mock_external_dependencies[ + "notion_extractor" + ].get_notion_last_edited_time.side_effect = _delete_dataset_before_clean + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + assert updated_document is not None + assert updated_document.indexing_status == "parsing" + mock_external_dependencies["index_processor"].clean.assert_not_called() + mock_external_dependencies["indexing_runner"].run.assert_called_once() + + def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies): + """Test that indexing continues when index cleanup fails.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + mock_external_dependencies["index_processor"].clean.side_effect = Exception("Cleaning 
error") + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + remaining_segments = ( + db_session_with_containers.query(DocumentSegment) + .where(DocumentSegment.document_id == context["document"].id) + .count() + ) + assert updated_document is not None + assert updated_document.indexing_status == "parsing" + assert remaining_segments == 0 + mock_external_dependencies["indexing_runner"].run.assert_called_once() + + def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies): + """Test that DocumentIsPausedError does not flip document into error state.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + mock_external_dependencies["indexing_runner"].run.side_effect = DocumentIsPausedError("Document paused") + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + ) + assert updated_document is not None + assert updated_document.indexing_status == "parsing" + assert updated_document.error is None + + def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): + """Test that indexing errors are persisted to document state.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + mock_external_dependencies["indexing_runner"].run.side_effect = Exception("Indexing error") + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + db_session_with_containers.expire_all() + updated_document = ( + db_session_with_containers.query(Document).where(Document.id == 
context["document"].id).first() + ) + assert updated_document is not None + assert updated_document.indexing_status == "error" + assert "Indexing error" in updated_document.error + assert updated_document.stopped_at is not None + + def test_index_processor_clean_called_with_correct_params( + self, + db_session_with_containers, + mock_external_dependencies, + ): + """Test that clean is called with dataset instance and collected node ids.""" + # Arrange + context = self._create_notion_sync_context(db_session_with_containers) + + # Act + document_indexing_sync_task(context["dataset"].id, context["document"].id) + + # Assert + clean_call_args = mock_external_dependencies["index_processor"].clean.call_args + assert clean_call_args is not None + clean_args, clean_kwargs = clean_call_args + assert getattr(clean_args[0], "id", None) == context["dataset"].id + assert set(clean_args[1]) == set(context["node_ids"]) + assert clean_kwargs.get("with_keywords") is True + assert clean_kwargs.get("delete_child_chunks") is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 0d266e7e76..4be1180c73 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -32,14 +32,11 @@ class TestDocumentIndexingTasks: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner, - patch("tasks.document_indexing_task.FeatureService") as mock_feature_service, + patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner, + patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service, ): # Setup mock indexing runner - mock_runner_instance = MagicMock() - 
mock_indexing_runner.return_value = mock_runner_instance - - # Setup mock feature service + mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service mock_features = MagicMock() mock_features.billing.enabled = False mock_feature_service.get_features.return_value = mock_features diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 7f37f84113..9da9a4132e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -16,15 +16,13 @@ class TestDocumentIndexingUpdateTask: - IndexingRunner.run([...]) """ with ( - patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory, - patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner, + patch("tasks.document_indexing_update_task.IndexProcessorFactory", autospec=True) as mock_factory, + patch("tasks.document_indexing_update_task.IndexingRunner", autospec=True) as mock_runner, ): processor_instance = MagicMock() mock_factory.return_value.init_index_processor.return_value = processor_instance - runner_instance = MagicMock() - mock_runner.return_value = runner_instance - + runner_instance = mock_runner.return_value yield { "factory": mock_factory, "processor": processor_instance, diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index fbcee899e1..b2e1ce3b89 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -31,15 +31,14 @@ class TestDuplicateDocumentIndexingTasks: def 
mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner, - patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service, - patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory, + patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner, + patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_feature_service, + patch( + "tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True + ) as mock_index_processor_factory, ): # Setup mock indexing runner - mock_runner_instance = MagicMock() - mock_indexing_runner.return_value = mock_runner_instance - - # Setup mock feature service + mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service mock_features = MagicMock() mock_features.billing.enabled = False mock_feature_service.get_features.return_value = mock_features @@ -650,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" - @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies ): @@ -693,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" - @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + 
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies ): @@ -737,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" - @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") + @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index b738646736..b3d9e49b30 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -18,7 +18,9 @@ class TestEnableSegmentsToIndexTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory, + patch( + "tasks.enable_segments_to_index_task.IndexProcessorFactory", autospec=True + ) as mock_index_processor_factory, ): # Setup mock index processor mock_processor = MagicMock() @@ -370,7 +372,7 @@ class TestEnableSegmentsToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Mock the get_child_chunks method for each segment - with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks: + with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks: # Setup mock to 
return child chunks for each segment mock_child_chunks = [] for i in range(2): # Each segment has 2 child chunks diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 31e9b67421..6c3a9ef20a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from faker import Faker @@ -16,16 +16,14 @@ class TestMailAccountDeletionTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.mail_account_deletion_task.mail") as mock_mail, - patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service, + patch("tasks.mail_account_deletion_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_account_deletion_task.get_email_i18n_service", autospec=True) as mock_get_email_service, ): # Setup mock mail service mock_mail.is_inited.return_value = True # Setup mock email service - mock_email_service = MagicMock() - mock_get_email_service.return_value = mock_email_service - + mock_email_service = mock_get_email_service.return_value yield { "mail": mock_mail, "get_email_service": mock_get_email_service, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 1aed7dc7cc..177af266fb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from faker import Faker @@ -15,16 
+15,14 @@ class TestMailChangeMailTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.mail_change_mail_task.mail") as mock_mail, - patch("tasks.mail_change_mail_task.get_email_i18n_service") as mock_get_email_i18n_service, + patch("tasks.mail_change_mail_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_change_mail_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service, ): # Setup mock mail service mock_mail.is_inited.return_value = True # Setup mock email i18n service - mock_email_service = MagicMock() - mock_get_email_i18n_service.return_value = mock_email_service - + mock_email_service = mock_get_email_i18n_service.return_value yield { "mail": mock_mail, "email_i18n_service": mock_email_service, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index e6a804784a..3cdec70df7 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -53,8 +53,8 @@ class TestSendEmailCodeLoginMailTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.mail_email_code_login.mail") as mock_mail, - patch("tasks.mail_email_code_login.get_email_i18n_service") as mock_email_service, + patch("tasks.mail_email_code_login.mail", autospec=True) as mock_mail, + patch("tasks.mail_email_code_login.get_email_i18n_service", autospec=True) as mock_email_service, ): # Setup default mock returns mock_mail.is_inited.return_value = True @@ -573,7 +573,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.side_effect = exception # Mock logging to capture error messages - with patch("tasks.mail_email_code_login.logger") as mock_logger: + 
with patch("tasks.mail_email_code_login.logger", autospec=True) as mock_logger: # Act: Execute the task - it should handle the exception gracefully send_email_code_login_mail_task( language=test_language, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py index d67794654f..1a20b6deec 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from faker import Faker @@ -13,18 +13,15 @@ class TestMailInnerTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.mail_inner_task.mail") as mock_mail, - patch("tasks.mail_inner_task.get_email_i18n_service") as mock_get_email_i18n_service, - patch("tasks.mail_inner_task._render_template_with_strategy") as mock_render_template, + patch("tasks.mail_inner_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_inner_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service, + patch("tasks.mail_inner_task._render_template_with_strategy", autospec=True) as mock_render_template, ): # Setup mock mail service mock_mail.is_inited.return_value = True # Setup mock email i18n service - mock_email_service = MagicMock() - mock_get_email_i18n_service.return_value = mock_email_service - - # Setup mock template rendering + mock_email_service = mock_get_email_i18n_service.return_value # Setup mock template rendering mock_render_template.return_value = "Test email content" yield { diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index c083861004..212fbd26cd 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -56,9 +56,9 @@ class TestMailInviteMemberTask: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("tasks.mail_invite_member_task.mail") as mock_mail, - patch("tasks.mail_invite_member_task.get_email_i18n_service") as mock_email_service, - patch("tasks.mail_invite_member_task.dify_config") as mock_config, + patch("tasks.mail_invite_member_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_invite_member_task.get_email_i18n_service", autospec=True) as mock_email_service, + patch("tasks.mail_invite_member_task.dify_config", autospec=True) as mock_config, ): # Setup mail service mock mock_mail.is_inited.return_value = True @@ -306,7 +306,7 @@ class TestMailInviteMemberTask: mock_email_service.send_email.side_effect = Exception("Email service failed") # Act & Assert: Execute task and verify exception is handled - with patch("tasks.mail_invite_member_task.logger") as mock_logger: + with patch("tasks.mail_invite_member_task.logger", autospec=True) as mock_logger: send_invite_member_mail_task( language="en-US", to="test@example.com", diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py index e128b06b11..e08b099480 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py @@ -7,7 +7,7 @@ testing with actual database and service dependencies. 
""" import logging -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from faker import Faker @@ -30,16 +30,14 @@ class TestMailOwnerTransferTask: def mock_mail_dependencies(self): """Mock setup for mail service dependencies.""" with ( - patch("tasks.mail_owner_transfer_task.mail") as mock_mail, - patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service, + patch("tasks.mail_owner_transfer_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_owner_transfer_task.get_email_i18n_service", autospec=True) as mock_get_email_service, ): # Setup mock mail service mock_mail.is_inited.return_value = True # Setup mock email service - mock_email_service = MagicMock() - mock_get_email_service.return_value = mock_email_service - + mock_email_service = mock_get_email_service.return_value yield { "mail": mock_mail, "email_service": mock_email_service, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index e4db14623d..cced6f7780 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -5,7 +5,7 @@ This module provides integration tests for email registration tasks using TestContainers to ensure real database and service interactions. 
""" -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from faker import Faker @@ -21,16 +21,14 @@ class TestMailRegisterTask: def mock_mail_dependencies(self): """Mock setup for mail service dependencies.""" with ( - patch("tasks.mail_register_task.mail") as mock_mail, - patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service, + patch("tasks.mail_register_task.mail", autospec=True) as mock_mail, + patch("tasks.mail_register_task.get_email_i18n_service", autospec=True) as mock_get_email_service, ): # Setup mock mail service mock_mail.is_inited.return_value = True # Setup mock email i18n service - mock_email_service = MagicMock() - mock_get_email_service.return_value = mock_email_service - + mock_email_service = mock_get_email_service.return_value yield { "mail": mock_mail, "email_service": mock_email_service, @@ -76,7 +74,7 @@ class TestMailRegisterTask: to_email = fake.email() code = fake.numerify("######") - with patch("tasks.mail_register_task.logger") as mock_logger: + with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger: send_email_register_mail_task(language="en-US", to=to_email, code=code) mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) @@ -89,7 +87,7 @@ class TestMailRegisterTask: to_email = fake.email() account_name = fake.name() - with patch("tasks.mail_register_task.dify_config") as mock_config: + with patch("tasks.mail_register_task.dify_config", autospec=True) as mock_config: mock_config.CONSOLE_WEB_URL = "https://console.dify.ai" send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name) @@ -129,6 +127,6 @@ class TestMailRegisterTask: to_email = fake.email() account_name = fake.name() - with patch("tasks.mail_register_task.logger") as mock_logger: + with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger: 
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name) mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 7ac9573ab7..8501a8e39b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,8 +4,8 @@ from unittest.mock import ANY, call, patch import pytest from core.db.session_factory import session_factory -from core.variables.segments import StringSegment -from core.variables.types import SegmentType +from core.workflow.variables.segments import StringSegment +from core.workflow.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py index 7bab73d6c6..460da06ecc 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -12,9 +12,17 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(): conversation.id = "conversation-id" with ( - patch("controllers.console.app.conversation.current_account_with_tenant", return_value=(account, None)), - patch("controllers.console.app.conversation.naive_utc_now", return_value=datetime(2026, 2, 9, 0, 0, 0)), - patch("controllers.console.app.conversation.db.session") as mock_session, + patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, None), + 
autospec=True, + ), + patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=datetime(2026, 2, 9, 0, 0, 0), + autospec=True, + ), + patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, ): mock_session.query.return_value.where.return_value.first.return_value = conversation diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index ec35366d02..cf10182ad3 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,8 +13,8 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variables.types import SegmentType from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -40,7 +40,7 @@ class TestWorkflowDraftVariableFields: mock_variable.variable_file = mock_variable_file # Mock the file helpers - with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers: + with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers: mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url" # Call the function @@ -203,7 +203,7 @@ class TestWorkflowDraftVariableFields: } ) - with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers: + with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers: mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url" 
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value expected_with_value = expected_without_value.copy() diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index 8da930b7fa..d010f60866 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -47,8 +47,8 @@ class TestRefreshTokenApi: token_pair.csrf_token = "new_csrf_token" return token_pair - @patch("controllers.console.auth.login.extract_refresh_token") - @patch("controllers.console.auth.login.AccountService.refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) + @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair): """ Test successful token refresh flow. @@ -73,7 +73,7 @@ class TestRefreshTokenApi: mock_refresh_token.assert_called_once_with("valid_refresh_token") assert response.json["result"] == "success" - @patch("controllers.console.auth.login.extract_refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) def test_refresh_fails_without_token(self, mock_extract_token, app): """ Test token refresh failure when no refresh token provided. 
@@ -96,8 +96,8 @@ class TestRefreshTokenApi: assert response["result"] == "fail" assert "No refresh token provided" in response["message"] - @patch("controllers.console.auth.login.extract_refresh_token") - @patch("controllers.console.auth.login.AccountService.refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) + @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app): """ Test token refresh failure with invalid refresh token. @@ -121,8 +121,8 @@ class TestRefreshTokenApi: assert response["result"] == "fail" assert "Invalid refresh token" in response["message"] - @patch("controllers.console.auth.login.extract_refresh_token") - @patch("controllers.console.auth.login.AccountService.refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) + @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app): """ Test token refresh failure with expired refresh token. @@ -146,8 +146,8 @@ class TestRefreshTokenApi: assert response["result"] == "fail" assert "expired" in response["message"].lower() - @patch("controllers.console.auth.login.extract_refresh_token") - @patch("controllers.console.auth.login.AccountService.refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) + @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app): """ Test token refresh with empty string token. 
@@ -168,8 +168,8 @@ class TestRefreshTokenApi: assert status_code == 401 assert response["result"] == "fail" - @patch("controllers.console.auth.login.extract_refresh_token") - @patch("controllers.console.auth.login.AccountService.refresh_token") + @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) + @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair): """ Test that token refresh updates all three tokens. diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 32b41baa27..85eb6e7d71 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -77,7 +77,7 @@ def _restx_mask_defaults(app: Flask): def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch): - service_result = {"entrypoint": "main:agent"} + service_result = [{"entrypoint": "main:agent"}] service_mock = MagicMock(return_value=service_result) monkeypatch.setattr( "controllers.console.extension.CodeBasedExtensionService.get_code_based_extension", diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index c608f731c5..b15676d9b7 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -39,10 +39,12 @@ def client(): @patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1") + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "t1"), + autospec=True, ) 
-@patch("controllers.console.workspace.tool_providers.Session") -@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url") +@patch("controllers.console.workspace.tool_providers.Session", autospec=True) +@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): # Arrange: reconnect returns tools immediately @@ -62,7 +64,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path mock_session.return_value.__enter__.return_value = MagicMock() # Patch MCPToolManageService constructed inside controller - with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc): + with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True): payload = { "server_url": "http://example.com/mcp", "name": "demo", @@ -77,12 +79,19 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ # Act with ( patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check - patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")), - patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required - patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login + patch( + "controllers.console.wraps.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "t1"), + autospec=True, + ), + patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required + patch( + "libs.login._get_user", return_value=MagicMock(id="u1", 
is_authenticated=True), autospec=True + ), # login patch( "services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider", return_value={"id": "provider-1", "tools": [{"name": "ping"}]}, + autospec=True, ), ): resp = client.post( diff --git a/api/tests/unit_tests/controllers/mcp/test_mcp.py b/api/tests/unit_tests/controllers/mcp/test_mcp.py index 862d611087..b93770e9c2 100644 --- a/api/tests/unit_tests/controllers/mcp/test_mcp.py +++ b/api/tests/unit_tests/controllers/mcp/test_mcp.py @@ -77,7 +77,7 @@ class DummyResult: class TestMCPAppApi: - @patch.object(module, "handle_mcp_request", return_value=DummyResult()) + @patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True) def test_success_request(self, mock_handle): fake_payload( { @@ -321,7 +321,7 @@ class TestMCPAppApi: post_fn("server-1") assert "App is unavailable" in str(exc_info.value) - @patch.object(module, "handle_mcp_request", return_value=None) + @patch.object(module, "handle_mcp_request", return_value=None, autospec=True) def test_mcp_request_no_response(self, mock_handle): """Test when handle_mcp_request returns None""" fake_payload( @@ -380,7 +380,7 @@ class TestMCPAppApi: api = module.MCPAppApi() api._get_mcp_server_and_app = MagicMock(return_value=(server, app)) - with patch.object(module, "handle_mcp_request", return_value=DummyResult()): + with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True): post_fn = unwrap(api.post) response = post_fn("server-1") assert isinstance(response, Response) @@ -409,7 +409,7 @@ class TestMCPAppApi: api = module.MCPAppApi() api._get_mcp_server_and_app = MagicMock(return_value=(server, app)) - with patch.object(module, "handle_mcp_request", return_value=DummyResult()): + with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True): post_fn = unwrap(api.post) response = post_fn("server-1") assert isinstance(response, Response) diff --git 
a/api/tests/unit_tests/controllers/service_api/test_index.py b/api/tests/unit_tests/controllers/service_api/test_index.py index ae484448a9..c560a3c698 100644 --- a/api/tests/unit_tests/controllers/service_api/test_index.py +++ b/api/tests/unit_tests/controllers/service_api/test_index.py @@ -12,7 +12,7 @@ from controllers.service_api.index import IndexApi class TestIndexApi: """Test suite for IndexApi resource.""" - @patch("controllers.service_api.index.dify_config") + @patch("controllers.service_api.index.dify_config", autospec=True) def test_get_returns_api_info(self, mock_config, app): """Test that GET returns API metadata with correct structure.""" # Arrange diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 3a4fdc3cd8..0ca54a2f4a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.variables import SegmentType +from core.workflow.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 0bbfd452e1..1931e230b2 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -71,17 +71,17 @@ class TestBaseAppRunnerMultimodal: mime_type="image/png", ) - with patch("core.app.apps.base_app_runner.ToolFileManager") as 
mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: mock_session.add = MagicMock() mock_session.commit = MagicMock() mock_session.refresh = MagicMock() @@ -158,17 +158,17 @@ class TestBaseAppRunnerMultimodal: mime_type="image/png", ) - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_raw.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: mock_session.add = MagicMock() mock_session.commit = MagicMock() mock_session.refresh = MagicMock() @@ -231,17 +231,17 @@ class TestBaseAppRunnerMultimodal: mime_type="image/png", ) - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock tool file 
manager mock_mgr = MagicMock() mock_mgr.create_file_by_raw.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: mock_session.add = MagicMock() mock_session.commit = MagicMock() mock_session.refresh = MagicMock() @@ -282,9 +282,9 @@ class TestBaseAppRunnerMultimodal: mime_type="image/png", ) - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: # Act # Create a mock runner with the method bound runner = MagicMock() @@ -321,14 +321,14 @@ class TestBaseAppRunnerMultimodal: mime_type="image/png", ) - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock to raise exception mock_mgr = MagicMock() mock_mgr.create_file_by_url.side_effect = Exception("Network error") mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: + 
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: # Act # Create a mock runner with the method bound runner = MagicMock() @@ -368,17 +368,17 @@ class TestBaseAppRunnerMultimodal: ) mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: mock_session.add = MagicMock() mock_session.commit = MagicMock() mock_session.refresh = MagicMock() @@ -420,17 +420,17 @@ class TestBaseAppRunnerMultimodal: ) mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API - with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class: + with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr - with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class: + with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session") as mock_session: + with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: 
mock_session.add = MagicMock() mock_session.commit = MagicMock() mock_session.refresh = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index f252324a85..5508a117c1 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from core.workflow.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py b/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py index 9557e78150..9e750bd595 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py @@ -84,7 +84,7 @@ def mock_time(): mock_time_val += seconds return mock_time_val - with patch("time.time", return_value=mock_time_val) as mock: + with patch("time.time", return_value=mock_time_val, autospec=True) as mock: mock.increment = increment_time yield mock diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index b6e8cc9c8e..d3ae577d0d 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -3,8 +3,6 @@ from datetime import datetime from unittest.mock import Mock from 
core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.variables import StringVariable -from core.variables.segments import Segment from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph_engine.protocols.command_channel import CommandChannel @@ -13,6 +11,8 @@ from core.workflow.node_events import NodeRunResult from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState from core.workflow.system_variable import SystemVariable +from core.workflow.variables import StringVariable +from core.workflow.variables.segments import Segment class MockReadOnlyVariablePool: diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 1d885f6b2e..539f0cb581 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,7 +13,6 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from core.variables.segments import Segment from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph_engine.entities.commands import GraphEngineCommand from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -24,6 +23,7 @@ from core.workflow.graph_events.graph import ( GraphRunSucceededEvent, ) from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from core.workflow.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py 
b/api/tests/unit_tests/core/datasource/test_datasource_manager.py new file mode 100644 index 0000000000..9ee1df8bdc --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -0,0 +1,135 @@ +import types +from collections.abc import Generator + +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent + + +def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage(text=text), + meta=None, + ) + + +def test_get_icon_url_calls_runtime(mocker): + fake_runtime = mocker.Mock() + fake_runtime.get_icon_url.return_value = "https://icon" + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime) + + url = DatasourceManager.get_icon_url( + provider_id="p/x", + tenant_id="t1", + datasource_name="ds", + datasource_type="online_document", + ) + assert url == "https://icon" + DatasourceManager.get_datasource_runtime.assert_called_once() + + +def test_stream_online_results_yields_messages_online_document(mocker): + # stub runtime to yield a text message + def _doc_messages(**_): + yield from _gen_messages_text_only("hello") + + fake_runtime = mocker.Mock() + fake_runtime.get_online_document_page_content.side_effect = _doc_messages + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value=None, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + 
plugin_id="plug", + credential_id="", + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + msgs = list(gen) + assert len(msgs) == 1 + assert msgs[0].message.text == "hello" + + +def test_stream_node_events_emits_events_online_document(mocker): + # make manager's low-level stream produce TEXT only + mocker.patch.object( + DatasourceManager, + "stream_online_results", + return_value=_gen_messages_text_only("hello"), + ) + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={"k": "v"}, + datasource_info={"user_id": "u1"}, + variable_pool=mocker.Mock(), + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + # should contain one StreamChunkEvent then a final chunk (empty) and a completed event + assert isinstance(events[0], StreamChunkEvent) + assert events[0].chunk == "hello" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + + +def test_get_upload_file_by_id_builds_file(mocker): + # fake UploadFile row + fake_row = types.SimpleNamespace( + id="fid", + name="f", + extension="txt", + mime_type="text/plain", + size=1, + key="k", + source_url="http://x", + ) + + class _Q: + def __init__(self, row): + self._row = row + + def where(self, *_args, **_kwargs): + return self + + def first(self): + return self._row + + class _S: + def __init__(self, row): + self._row = row + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def query(self, *_): + return _Q(self._row) + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row)) + + f = 
DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") + assert f.related_id == "fid" + assert f.extension == ".txt" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d6d75fb72f..3b5c5e6597 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -9,7 +9,7 @@ from core.helper.ssrf_proxy import ( ) -@patch("core.helper.ssrf_proxy._get_ssrf_client") +@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_successful_request(mock_get_client): mock_client = MagicMock() mock_response = MagicMock() @@ -22,7 +22,7 @@ def test_successful_request(mock_get_client): mock_client.request.assert_called_once() -@patch("core.helper.ssrf_proxy._get_ssrf_client") +@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_retry_exceed_max_retries(mock_get_client): mock_client = MagicMock() mock_response = MagicMock() @@ -71,7 +71,7 @@ class TestGetUserProvidedHostHeader: assert result in ("first.com", "second.com") -@patch("core.helper.ssrf_proxy._get_ssrf_client") +@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_host_header_preservation_with_user_header(mock_get_client): """Test that user-provided Host header is preserved in the request.""" mock_client = MagicMock() @@ -89,7 +89,7 @@ def test_host_header_preservation_with_user_header(mock_get_client): assert call_kwargs["headers"]["host"] == custom_host -@patch("core.helper.ssrf_proxy._get_ssrf_client") +@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) @pytest.mark.parametrize("host_key", ["host", "HOST", "Host"]) def test_host_header_preservation_case_insensitive(mock_get_client, host_key): """Test that Host header is preserved regardless of case.""" @@ -113,7 +113,7 @@ class TestFollowRedirectsParameter: These tests verify that follow_redirects is correctly passed to client.request(). 
""" - @patch("core.helper.ssrf_proxy._get_ssrf_client") + @patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_follow_redirects_passed_to_request(self, mock_get_client): """Verify follow_redirects IS passed to client.request().""" mock_client = MagicMock() @@ -128,7 +128,7 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True - @patch("core.helper.ssrf_proxy._get_ssrf_client") + @patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client): """Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style).""" mock_client = MagicMock() @@ -145,7 +145,7 @@ class TestFollowRedirectsParameter: assert call_kwargs.get("follow_redirects") is True assert "allow_redirects" not in call_kwargs - @patch("core.helper.ssrf_proxy._get_ssrf_client") + @patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_follow_redirects_not_set_when_not_specified(self, mock_get_client): """Verify follow_redirects is not in kwargs when not specified (httpx default behavior).""" mock_client = MagicMock() @@ -160,7 +160,7 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert "follow_redirects" not in call_kwargs - @patch("core.helper.ssrf_proxy._get_ssrf_client") + @patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True) def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client): """Verify follow_redirects takes precedence when both are specified.""" mock_client = MagicMock() diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index b66ad111d5..7c2767266f 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -72,7 +72,7 @@ class TestTraceContextFilter: 
mock_span.get_span_context.return_value = mock_context with ( - mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True), mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), ): @@ -108,7 +108,9 @@ class TestIdentityContextFilter: filter = IdentityContextFilter() # Should not raise even if something goes wrong - with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")): + with mock.patch( + "core.logging.filters.flask.has_request_context", side_effect=Exception("Test error"), autospec=True + ): result = filter.filter(log_record) assert result is True assert log_record.tenant_id == "" diff --git a/api/tests/unit_tests/core/logging/test_trace_helpers.py b/api/tests/unit_tests/core/logging/test_trace_helpers.py index aab1753b9b..1b44553bff 100644 --- a/api/tests/unit_tests/core/logging/test_trace_helpers.py +++ b/api/tests/unit_tests/core/logging/test_trace_helpers.py @@ -8,7 +8,7 @@ class TestGetSpanIdFromOtelContext: def test_returns_none_without_span(self): from core.helper.trace_id_helper import get_span_id_from_otel_context - with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True): result = get_span_id_from_otel_context() assert result is None @@ -20,7 +20,7 @@ class TestGetSpanIdFromOtelContext: mock_context.span_id = 0x051581BF3BB55C45 mock_span.get_span_context.return_value = mock_context - with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): + with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True): with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0): result = get_span_id_from_otel_context() assert result == "051581bf3bb55c45" @@ -28,7 +28,7 @@ class 
TestGetSpanIdFromOtelContext: def test_returns_none_on_exception(self): from core.helper.trace_id_helper import get_span_id_from_otel_context - with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")): + with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error"), autospec=True): result = get_span_id_from_otel_context() assert result is None @@ -37,7 +37,7 @@ class TestGenerateTraceparentHeader: def test_generates_valid_format(self): from core.helper.trace_id_helper import generate_traceparent_header - with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True): result = generate_traceparent_header() assert result is not None @@ -58,7 +58,7 @@ class TestGenerateTraceparentHeader: mock_context.span_id = 0x051581BF3BB55C45 mock_span.get_span_context.return_value = mock_context - with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span): + with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True): with ( mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), @@ -70,7 +70,7 @@ class TestGenerateTraceparentHeader: def test_generates_hex_only_values(self): from core.helper.trace_id_helper import generate_traceparent_header - with mock.patch("opentelemetry.trace.get_current_span", return_value=None): + with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True): result = generate_traceparent_header() parts = result.split("-") diff --git a/api/tests/unit_tests/core/mcp/test_utils.py b/api/tests/unit_tests/core/mcp/test_utils.py index ca41d5f4c1..5ef2f703cd 100644 --- a/api/tests/unit_tests/core/mcp/test_utils.py +++ b/api/tests/unit_tests/core/mcp/test_utils.py @@ -32,7 +32,7 @@ class TestConstants: class TestCreateSSRFProxyMCPHTTPClient: """Test 
create_ssrf_proxy_mcp_http_client function.""" - @patch("core.mcp.utils.dify_config") + @patch("core.mcp.utils.dify_config", autospec=True) def test_create_client_with_all_url_proxy(self, mock_config): """Test client creation with SSRF_PROXY_ALL_URL configured.""" mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080" @@ -50,7 +50,7 @@ class TestCreateSSRFProxyMCPHTTPClient: # Clean up client.close() - @patch("core.mcp.utils.dify_config") + @patch("core.mcp.utils.dify_config", autospec=True) def test_create_client_with_http_https_proxies(self, mock_config): """Test client creation with separate HTTP/HTTPS proxies.""" mock_config.SSRF_PROXY_ALL_URL = None @@ -66,7 +66,7 @@ class TestCreateSSRFProxyMCPHTTPClient: # Clean up client.close() - @patch("core.mcp.utils.dify_config") + @patch("core.mcp.utils.dify_config", autospec=True) def test_create_client_without_proxy(self, mock_config): """Test client creation without proxy configuration.""" mock_config.SSRF_PROXY_ALL_URL = None @@ -88,7 +88,7 @@ class TestCreateSSRFProxyMCPHTTPClient: # Clean up client.close() - @patch("core.mcp.utils.dify_config") + @patch("core.mcp.utils.dify_config", autospec=True) def test_create_client_default_params(self, mock_config): """Test client creation with default parameters.""" mock_config.SSRF_PROXY_ALL_URL = None @@ -111,8 +111,8 @@ class TestCreateSSRFProxyMCPHTTPClient: class TestSSRFProxySSEConnect: """Test ssrf_proxy_sse_connect function.""" - @patch("core.mcp.utils.connect_sse") - @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + @patch("core.mcp.utils.connect_sse", autospec=True) + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True) def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse): """Test SSE connection with pre-configured client.""" # Setup mocks @@ -138,9 +138,9 @@ class TestSSRFProxySSEConnect: # Verify result assert result == mock_context - @patch("core.mcp.utils.connect_sse") - 
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") - @patch("core.mcp.utils.dify_config") + @patch("core.mcp.utils.connect_sse", autospec=True) + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True) + @patch("core.mcp.utils.dify_config", autospec=True) def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse): """Test SSE connection without pre-configured client.""" # Setup config @@ -183,8 +183,8 @@ class TestSSRFProxySSEConnect: # Verify result assert result == mock_context - @patch("core.mcp.utils.connect_sse") - @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + @patch("core.mcp.utils.connect_sse", autospec=True) + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True) def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse): """Test SSE connection with custom timeout.""" # Setup mocks @@ -209,8 +209,8 @@ class TestSSRFProxySSEConnect: # Verify result assert result == mock_context - @patch("core.mcp.utils.connect_sse") - @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client") + @patch("core.mcp.utils.connect_sse", autospec=True) + @patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True) def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse): """Test SSE connection cleans up client on error.""" # Setup mocks @@ -227,7 +227,7 @@ class TestSSRFProxySSEConnect: # Verify client was cleaned up mock_client.close.assert_called_once() - @patch("core.mcp.utils.connect_sse") + @patch("core.mcp.utils.connect_sse", autospec=True) def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse): """Test SSE connection doesn't clean up provided client on error.""" # Setup mocks diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 
cfdeef6a8d..09d527cb12 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -103,16 +103,16 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults(): assert result.system_fingerprint is None -def test__normalize_non_stream_plugin_result__closes_chunk_iterator(): +def test__normalize_non_stream_plugin_result__accumulates_all_chunks(): + """All chunks are accumulated from the iterator.""" prompt_messages = [UserPromptMessage(content="hi")] - chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage()) closed: list[bool] = [] def _chunk_iter(): try: - yield chunk - yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage()) + yield _make_chunk(content="hello", usage=LLMUsage.empty_usage()) + yield _make_chunk(content=" world", usage=LLMUsage.empty_usage()) finally: closed.append(True) @@ -122,5 +122,5 @@ def test__normalize_non_stream_plugin_result__closes_chunk_iterator(): result=_chunk_iter(), ) - assert result.message.content == "hello" + assert result.message.content == "hello world" assert closed == [True] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index 1a577f9b7f..e61cde22e7 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, 
openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def 
test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -494,7 +494,7 @@ class TestModerationRuleStructure: class TestModerationFactoryIntegration: """Test suite for ModerationFactory integration.""" - @patch("core.moderation.factory.code_based_extension") + @patch("core.moderation.factory.code_based_extension", autospec=True) def test_factory_delegates_to_extension(self, mock_extension: Mock): """Test ModerationFactory delegates to 
extension system.""" from core.moderation.factory import ModerationFactory @@ -518,7 +518,7 @@ class TestModerationFactoryIntegration: assert result.flagged is False mock_instance.moderation_for_inputs.assert_called_once() - @patch("core.moderation.factory.code_based_extension") + @patch("core.moderation.factory.code_based_extension", autospec=True) def test_factory_validate_config_delegates(self, mock_extension: Mock): """Test ModerationFactory.validate_config delegates to extension.""" from core.moderation.factory import ModerationFactory @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. 
@@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager") + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. 
diff --git a/api/tests/unit_tests/core/plugin/test_endpoint_client.py b/api/tests/unit_tests/core/plugin/test_endpoint_client.py index 53056ee42a..48e30e9c2f 100644 --- a/api/tests/unit_tests/core/plugin/test_endpoint_client.py +++ b/api/tests/unit_tests/core/plugin/test_endpoint_client.py @@ -64,7 +64,7 @@ class TestPluginEndpointClientDelete: "data": True, } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = endpoint_client.delete_endpoint( tenant_id=tenant_id, @@ -102,7 +102,7 @@ class TestPluginEndpointClientDelete: ), } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = endpoint_client.delete_endpoint( tenant_id=tenant_id, @@ -139,7 +139,7 @@ class TestPluginEndpointClientDelete: ), } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonInternalServerError) as exc_info: endpoint_client.delete_endpoint( @@ -174,7 +174,7 @@ class TestPluginEndpointClientDelete: "message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}', } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = endpoint_client.delete_endpoint( tenant_id=tenant_id, @@ -222,7 +222,7 @@ class TestPluginEndpointClientDelete: ), } - with patch("httpx.request") as mock_request: + with patch("httpx.request", autospec=True) as mock_request: # Act - first call mock_request.return_value = mock_response_success result1 = endpoint_client.delete_endpoint( @@ -266,7 +266,7 @@ class TestPluginEndpointClientDelete: "message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}', } - with patch("httpx.request", return_value=mock_response): + 
with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(Exception) as exc_info: endpoint_client.delete_endpoint( diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 9e911e1fce..9e871fcb74 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -114,7 +114,7 @@ class TestPluginRuntimeExecution: mock_response.status_code = 200 mock_response.json.return_value = {"result": "success"} - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act response = plugin_client._request("GET", "plugin/test-tenant/management/list") @@ -132,7 +132,7 @@ class TestPluginRuntimeExecution: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("GET", "plugin/test-tenant/test") @@ -143,7 +143,7 @@ class TestPluginRuntimeExecution: def test_request_connection_error(self, plugin_client, mock_config): """Test handling of connection errors during request.""" # Arrange - with patch("httpx.request", side_effect=httpx.RequestError("Connection failed")): + with patch("httpx.request", side_effect=httpx.RequestError("Connection failed"), autospec=True): # Act & Assert with pytest.raises(PluginDaemonInnerError) as exc_info: plugin_client._request("GET", "plugin/test-tenant/test") @@ -182,7 +182,7 @@ class TestPluginRuntimeSandboxIsolation: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": True} - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, 
autospec=True) as mock_request: # Act plugin_client._request("GET", "plugin/test-tenant/test") @@ -201,7 +201,7 @@ class TestPluginRuntimeSandboxIsolation: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": {"result": "isolated_execution"}} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = plugin_client._request_with_plugin_daemon_response( "POST", "plugin/test-tenant/dispatch/tool/invoke", TestResponse, data={"tool": "test"} @@ -218,7 +218,7 @@ class TestPluginRuntimeSandboxIsolation: error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Unauthorized access"}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonUnauthorizedError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -234,7 +234,7 @@ class TestPluginRuntimeSandboxIsolation: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginPermissionDeniedError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool) @@ -272,7 +272,7 @@ class TestPluginRuntimeResourceLimits: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("GET", "plugin/test-tenant/test") @@ -283,7 +283,7 @@ class TestPluginRuntimeResourceLimits: def 
test_timeout_error_handling(self, plugin_client, mock_config): """Test handling of timeout errors.""" # Arrange - with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout")): + with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout"), autospec=True): # Act & Assert with pytest.raises(PluginDaemonInnerError) as exc_info: plugin_client._request("GET", "plugin/test-tenant/test") @@ -292,7 +292,7 @@ class TestPluginRuntimeResourceLimits: def test_streaming_request_timeout(self, plugin_client, mock_config): """Test timeout handling for streaming requests.""" # Arrange - with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout")): + with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout"), autospec=True): # Act & Assert with pytest.raises(PluginDaemonInnerError) as exc_info: list(plugin_client._stream_request("POST", "plugin/test-tenant/stream")) @@ -308,7 +308,7 @@ class TestPluginRuntimeResourceLimits: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonInternalServerError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool) @@ -352,7 +352,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(InvokeRateLimitError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -371,7 +371,7 @@ class TestPluginRuntimeErrorHandling: 
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(InvokeAuthorizationError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -390,7 +390,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(InvokeBadRequestError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -409,7 +409,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(InvokeConnectionError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -428,7 +428,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with 
pytest.raises(InvokeServerUnavailableError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -446,7 +446,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(CredentialsValidateFailedError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/validate", bool) @@ -462,7 +462,7 @@ class TestPluginRuntimeErrorHandling: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginNotFoundError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/get", bool) @@ -478,7 +478,7 @@ class TestPluginRuntimeErrorHandling: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginUniqueIdentifierError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/install", bool) @@ -494,7 +494,7 @@ class TestPluginRuntimeErrorHandling: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonBadRequestError) as exc_info: 
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool) @@ -508,7 +508,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginDaemonNotFoundError", "message": "Resource not found"}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonNotFoundError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/resource", bool) @@ -526,7 +526,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "PluginInvokeError", "message": invoke_error_message}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginInvokeError) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool) @@ -540,7 +540,7 @@ class TestPluginRuntimeErrorHandling: error_message = json.dumps({"error_type": "UnknownErrorType", "message": "Unknown error occurred"}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(Exception) as exc_info: plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool) @@ -555,7 +555,7 @@ class TestPluginRuntimeErrorHandling: "Server Error", request=MagicMock(), response=mock_response ) - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with 
pytest.raises(httpx.HTTPStatusError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -567,7 +567,7 @@ class TestPluginRuntimeErrorHandling: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(ValueError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -610,7 +610,7 @@ class TestPluginRuntimeCommunication: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": {"value": "test", "count": 42}} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = plugin_client._request_with_plugin_daemon_response( "POST", "plugin/test-tenant/test", TestModel, data={"input": "data"} @@ -637,7 +637,7 @@ class TestPluginRuntimeCommunication: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -667,7 +667,7 @@ class TestPluginRuntimeCommunication: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -689,7 +689,7 @@ class TestPluginRuntimeCommunication: def test_streaming_connection_error(self, plugin_client, mock_config): """Test connection error during streaming.""" # Arrange - with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection 
failed")): + with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection failed"), autospec=True): # Act & Assert with pytest.raises(PluginDaemonInnerError) as exc_info: list(plugin_client._stream_request("POST", "plugin/test-tenant/stream")) @@ -707,7 +707,7 @@ class TestPluginRuntimeCommunication: mock_response.status_code = 200 mock_response.json.return_value = {"status": "success", "data": {"key": "value"}} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = plugin_client._request_with_model("GET", "plugin/test-tenant/direct", DirectModel) @@ -732,7 +732,7 @@ class TestPluginRuntimeCommunication: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -764,7 +764,7 @@ class TestPluginRuntimeCommunication: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -814,7 +814,7 @@ class TestPluginToolManagerIntegration: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -844,7 +844,7 @@ class TestPluginToolManagerIntegration: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: 
mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -868,7 +868,7 @@ class TestPluginToolManagerIntegration: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -892,7 +892,7 @@ class TestPluginToolManagerIntegration: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -945,7 +945,7 @@ class TestPluginInstallerIntegration: }, } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.list_plugins("test-tenant") @@ -959,7 +959,7 @@ class TestPluginInstallerIntegration: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": True} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.uninstall("test-tenant", "plugin-installation-id") @@ -973,7 +973,7 @@ class TestPluginInstallerIntegration: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": True} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.fetch_plugin_by_identifier("test-tenant", "plugin-identifier") @@ -1012,7 +1012,7 @@ class TestPluginRuntimeEdgeCases: mock_response.status_code = 200 mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) - with patch("httpx.request", 
return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(ValueError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -1025,7 +1025,7 @@ class TestPluginRuntimeEdgeCases: # Missing required fields in response mock_response.json.return_value = {"invalid": "structure"} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(ValueError): plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -1041,7 +1041,7 @@ class TestPluginRuntimeEdgeCases: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1065,7 +1065,7 @@ class TestPluginRuntimeEdgeCases: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("POST", "plugin/test-tenant/upload", data=b"binary data") @@ -1081,7 +1081,7 @@ class TestPluginRuntimeEdgeCases: files = {"file": ("test.txt", b"file content", "text/plain")} - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("POST", "plugin/test-tenant/upload", files=files) @@ -1095,7 +1095,7 @@ class TestPluginRuntimeEdgeCases: mock_response = MagicMock() mock_response.iter_lines.return_value = [] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: 
mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1115,7 +1115,7 @@ class TestPluginRuntimeEdgeCases: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act & Assert @@ -1136,7 +1136,7 @@ class TestPluginRuntimeEdgeCases: mock_response.status_code = 200 mock_response.json.return_value = {"code": -1, "message": "Plain text error message", "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(ValueError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -1174,7 +1174,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": True} - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act for i in range(5): result = plugin_client._request_with_plugin_daemon_response("GET", f"plugin/test-tenant/test/{i}", bool) @@ -1203,7 +1203,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": complex_data} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = plugin_client._request_with_plugin_daemon_response( "POST", "plugin/test-tenant/complex", ComplexModel @@ -1231,7 +1231,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") 
as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1262,7 +1262,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response.status_code = 200 return mock_response - with patch("httpx.request", side_effect=side_effect): + with patch("httpx.request", side_effect=side_effect, autospec=True): # Act & Assert - First two calls should fail with pytest.raises(PluginDaemonInnerError): plugin_client._request("GET", "plugin/test-tenant/test") @@ -1286,7 +1286,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("GET", "plugin/test-tenant/test", headers=custom_headers) @@ -1312,7 +1312,7 @@ class TestPluginRuntimeAdvancedScenarios: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1359,7 +1359,7 @@ class TestPluginRuntimeSecurityAndValidation: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request("GET", "plugin/test-tenant/test") @@ -1381,7 +1381,7 @@ class TestPluginRuntimeSecurityAndValidation: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, "message": "", "data": True} - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act 
plugin_client._request_with_plugin_daemon_response( "POST", @@ -1403,7 +1403,7 @@ class TestPluginRuntimeSecurityAndValidation: error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Invalid API key"}) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonUnauthorizedError) as exc_info: plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) @@ -1424,7 +1424,7 @@ class TestPluginRuntimeSecurityAndValidation: ) mock_response.json.return_value = {"code": -1, "message": error_message, "data": None} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert with pytest.raises(PluginDaemonBadRequestError) as exc_info: plugin_client._request_with_plugin_daemon_response( @@ -1438,7 +1438,7 @@ class TestPluginRuntimeSecurityAndValidation: mock_response = MagicMock() mock_response.status_code = 200 - with patch("httpx.request", return_value=mock_response) as mock_request: + with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request: # Act plugin_client._request( "POST", "plugin/test-tenant/test", headers={"Content-Type": "application/json"}, data={"key": "value"} @@ -1489,7 +1489,7 @@ class TestPluginRuntimePerformanceScenarios: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1524,7 +1524,7 @@ class TestPluginRuntimePerformanceScenarios: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in 
stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act - Process chunks one by one @@ -1539,7 +1539,7 @@ class TestPluginRuntimePerformanceScenarios: def test_timeout_with_slow_response(self, plugin_client, mock_config): """Test timeout handling with slow response simulation.""" # Arrange - with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s")): + with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s"), autospec=True): # Act & Assert with pytest.raises(PluginDaemonInnerError) as exc_info: plugin_client._request("GET", "plugin/test-tenant/slow-endpoint") @@ -1554,7 +1554,7 @@ class TestPluginRuntimePerformanceScenarios: request_results = [] - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act - Simulate 10 concurrent requests for i in range(10): result = plugin_client._request_with_plugin_daemon_response( @@ -1612,7 +1612,7 @@ class TestPluginToolManagerAdvanced: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1641,7 +1641,7 @@ class TestPluginToolManagerAdvanced: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1673,7 +1673,7 @@ class TestPluginToolManagerAdvanced: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with 
patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1704,7 +1704,7 @@ class TestPluginToolManagerAdvanced: mock_response = MagicMock() mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data] - with patch("httpx.stream") as mock_stream: + with patch("httpx.stream", autospec=True) as mock_stream: mock_stream.return_value.__enter__.return_value = mock_response # Act @@ -1770,7 +1770,7 @@ class TestPluginInstallerAdvanced: }, } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.upload_pkg("test-tenant", plugin_package, verify_signature=False) @@ -1788,7 +1788,7 @@ class TestPluginInstallerAdvanced: "data": {"content": "# Plugin README\n\nThis is a test plugin.", "language": "en"}, } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") @@ -1807,7 +1807,7 @@ class TestPluginInstallerAdvanced: mock_response.raise_for_status = raise_for_status - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert - Should raise HTTPStatusError for 404 with pytest.raises(httpx.HTTPStatusError): installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") @@ -1826,7 +1826,7 @@ class TestPluginInstallerAdvanced: }, } - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.list_plugins_with_total("test-tenant", page=2, page_size=20) @@ -1848,7 +1848,7 @@ class TestPluginInstallerAdvanced: mock_response.status_code = 200 mock_response.json.return_value = {"code": 0, 
"message": "", "data": [True, False]} - with patch("httpx.request", return_value=mock_response): + with patch("httpx.request", return_value=mock_response, autospec=True): # Act result = installer.check_tools_existence("test-tenant", provider_ids) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index f07e55d534..1d25639343 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -142,7 +142,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("core.workflow.file.file_manager.to_prompt_message_content") as mock_get_encoded_string: + with patch("core.workflow.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git a/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py index 3167a9a301..47222a23a2 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_pdf_extractor.py @@ -83,7 +83,7 @@ def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, exp extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") # We need to handle the import inside _extract_images - with patch("pypdfium2.raw") as mock_raw: + with patch("pypdfium2.raw", autospec=True) as mock_raw: mock_raw.FPDF_PAGEOBJ_IMAGE = 1 result = extractor._extract_images(mock_page) @@ -115,7 +115,7 @@ def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_sid extractor = 
pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") - with patch("pypdfium2.raw") as mock_raw: + with patch("pypdfium2.raw", autospec=True) as mock_raw: mock_raw.FPDF_PAGEOBJ_IMAGE = 1 result = extractor._extract_images(mock_page) @@ -133,11 +133,11 @@ def test_extract_calls_extract_images(mock_dependencies, monkeypatch): mock_text_page.get_text_range.return_value = "Page text content" mock_page.get_textpage.return_value = mock_text_page - with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc): + with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc, autospec=True): # Mock Blob mock_blob = MagicMock() mock_blob.source = "test.pdf" - with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob): + with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob, autospec=True): extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") # Mock _extract_images to return a known string @@ -175,7 +175,7 @@ def test_extract_images_failures(mock_dependencies): extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1") - with patch("pypdfium2.raw") as mock_raw: + with patch("pypdfium2.raw", autospec=True) as mock_raw: mock_raw.FPDF_PAGEOBJ_IMAGE = 1 result = extractor._extract_images(mock_page) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 3cecc92c16..e4597e7f8c 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -52,7 +52,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value 
= False yield mock_mm @@ -397,19 +397,19 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: yield mock_manager @pytest.fixture def mock_cache_embedding(self): """Mock CacheEmbedding for vector operations.""" - with patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache: + with patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache: yield mock_cache @pytest.fixture def mock_jieba_handler(self): """Mock JiebaKeywordTableHandler for keyword extraction.""" - with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba: + with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba: yield mock_jieba @pytest.fixture @@ -914,7 +914,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1026,7 +1026,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1295,9 +1295,9 @@ class TestRerankEdgeCases: # Mock dependencies with ( - patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba, - 
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager, - patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache, + patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, + patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() mock_handler.extract_keywords.return_value = ["test"] @@ -1367,7 +1367,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1441,9 +1441,9 @@ class TestRerankPerformance: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights) with ( - patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager, - patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache, + patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, + patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() # Track keyword extraction calls @@ -1484,7 +1484,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = 
False yield mock_mm @@ -1592,9 +1592,9 @@ class TestRerankErrorHandling: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights) with ( - patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager, - patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache, + patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, + patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() mock_handler.extract_keywords.return_value = ["test"] diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 30f51902ef..7f1e2c5e5b 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -48,7 +48,7 @@ class TestRepositoryFactory: import_string("invalidpath") assert "doesn't look like a module path" in str(exc_info.value) - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_execution_repository_success(self, mock_config): """Test successful WorkflowExecutionRepository creation.""" # Setup mock configuration @@ -66,7 +66,7 @@ class TestRepositoryFactory: mock_repository_class.return_value = mock_repository_instance # Mock import_string - with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -83,7 +83,7 @@ class TestRepositoryFactory: ) assert result is mock_repository_instance - 
@patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_execution_repository_import_error(self, mock_config): """Test WorkflowExecutionRepository creation with import error.""" # Setup mock configuration with invalid class path @@ -101,7 +101,7 @@ class TestRepositoryFactory: ) assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_execution_repository_instantiation_error(self, mock_config): """Test WorkflowExecutionRepository creation with instantiation error.""" # Setup mock configuration @@ -115,7 +115,7 @@ class TestRepositoryFactory: mock_repository_class.side_effect = Exception("Instantiation failed") # Mock import_string to return a failing class - with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, @@ -125,7 +125,7 @@ class TestRepositoryFactory: ) assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_node_execution_repository_success(self, mock_config): """Test successful WorkflowNodeExecutionRepository creation.""" # Setup mock configuration @@ -143,7 +143,7 @@ class TestRepositoryFactory: mock_repository_class.return_value = mock_repository_instance # Mock import_string - with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + with patch("core.repositories.factory.import_string", return_value=mock_repository_class, 
autospec=True): result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -160,7 +160,7 @@ class TestRepositoryFactory: ) assert result is mock_repository_instance - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_node_execution_repository_import_error(self, mock_config): """Test WorkflowNodeExecutionRepository creation with import error.""" # Setup mock configuration with invalid class path @@ -178,7 +178,7 @@ class TestRepositoryFactory: ) assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): """Test WorkflowNodeExecutionRepository creation with instantiation error.""" # Setup mock configuration @@ -192,7 +192,7 @@ class TestRepositoryFactory: mock_repository_class.side_effect = Exception("Instantiation failed") # Mock import_string to return a failing class - with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, @@ -208,7 +208,7 @@ class TestRepositoryFactory: error = RepositoryImportError(error_message) assert str(error) == error_message - @patch("core.repositories.factory.dify_config") + @patch("core.repositories.factory.dify_config", autospec=True) def test_create_with_engine_instead_of_sessionmaker(self, mock_config): """Test repository creation with Engine instead of sessionmaker.""" # Setup mock configuration @@ -226,7 +226,7 @@ class TestRepositoryFactory: 
mock_repository_class.return_value = mock_repository_instance # Mock import_string - with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_engine, # Using Engine instead of sessionmaker user=mock_user, diff --git a/api/tests/unit_tests/core/schemas/test_resolver.py b/api/tests/unit_tests/core/schemas/test_resolver.py index 239ee85346..90827de894 100644 --- a/api/tests/unit_tests/core/schemas/test_resolver.py +++ b/api/tests/unit_tests/core/schemas/test_resolver.py @@ -196,7 +196,7 @@ class TestSchemaResolver: resolved1 = resolve_dify_schema_refs(schema) # Mock the registry to return different data - with patch.object(self.registry, "get_schema") as mock_get: + with patch.object(self.registry, "get_schema", autospec=True) as mock_get: mock_get.return_value = {"type": "different"} # Second resolution should use cache @@ -445,7 +445,7 @@ class TestSchemaResolverClass: # Second resolver should use the same cache resolver2 = SchemaResolver() - with patch.object(resolver2.registry, "get_schema") as mock_get: + with patch.object(resolver2.registry, "get_schema", autospec=True) as mock_get: result2 = resolver2.resolve(schema) # Should not call registry since it's in cache mock_get.assert_not_called() diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index bb9e381834..a9af8bea1d 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -3,7 +3,10 @@ import dataclasses from pydantic import BaseModel from core.helper import encrypter -from core.variables.segments import ( +from core.workflow.file import File, FileTransferMethod, FileType +from core.workflow.runtime import VariablePool +from 
core.workflow.system_variable import SystemVariable +from core.workflow.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -19,8 +22,8 @@ from core.variables.segments import ( StringSegment, get_segment_discriminator, ) -from core.variables.types import SegmentType -from core.variables.variables import ( +from core.workflow.variables.types import SegmentType +from core.workflow.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -35,9 +38,6 @@ from core.variables.variables import ( StringVariable, Variable, ) -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable def test_segment_group_to_text(): diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 3bfc5a957f..e28fed187b 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,6 +1,6 @@ import pytest -from core.variables.types import ArrayValidation, SegmentType +from core.workflow.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 0ec0fc536e..52e5dd180c 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -10,8 +10,10 @@ from typing import Any import pytest -from core.variables.segment_group import SegmentGroup -from core.variables.segments import ( +from core.workflow.file.enums import FileTransferMethod, FileType +from core.workflow.file.models import File +from core.workflow.variables.segment_group import SegmentGroup +from core.workflow.variables.segments import ( 
ArrayFileSegment, BooleanSegment, FileSegment, @@ -20,9 +22,7 @@ from core.variables.segments import ( ObjectSegment, StringSegment, ) -from core.variables.types import ArrayValidation, SegmentType -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File +from core.workflow.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index fb4b18b57a..6fc162e533 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from core.variables import ( +from core.workflow.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) -from core.variables.variables import VariableBase +from core.workflow.variables.variables import VariableBase def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py index a809b29552..abfb1e85ca 100644 --- a/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_flask_app_context.py @@ -138,8 +138,8 @@ class TestFlaskExecutionContext: class TestCaptureFlaskContext: """Test capture_flask_context function.""" - @patch("context.flask_app_context.current_app") - @patch("context.flask_app_context.g") + @patch("context.flask_app_context.current_app", autospec=True) + @patch("context.flask_app_context.g", autospec=True) def test_capture_flask_context_captures_app(self, mock_g, mock_current_app): """Test capture_flask_context captures Flask app.""" mock_app = MagicMock() @@ -152,8 +152,8 @@ class TestCaptureFlaskContext: assert ctx._flask_app == mock_app 
- @patch("context.flask_app_context.current_app") - @patch("context.flask_app_context.g") + @patch("context.flask_app_context.current_app", autospec=True) + @patch("context.flask_app_context.g", autospec=True) def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app): """Test capture_flask_context captures user from Flask g object.""" mock_app = MagicMock() @@ -170,7 +170,7 @@ class TestCaptureFlaskContext: assert ctx.user == mock_user - @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.current_app", autospec=True) def test_capture_flask_context_with_explicit_user(self, mock_current_app): """Test capture_flask_context uses explicit user parameter.""" mock_app = MagicMock() @@ -186,7 +186,7 @@ class TestCaptureFlaskContext: assert ctx.user == explicit_user - @patch("context.flask_app_context.current_app") + @patch("context.flask_app_context.current_app", autospec=True) def test_capture_flask_context_captures_contextvars(self, mock_current_app): """Test capture_flask_context captures context variables.""" mock_app = MagicMock() @@ -267,7 +267,7 @@ class TestFlaskExecutionContextIntegration: # Verify app context was entered assert mock_flask_app.app_context.called - @patch("context.flask_app_context.g") + @patch("context.flask_app_context.g", autospec=True) def test_enter_restores_user_in_g(self, mock_g, mock_flask_app): """Test that enter restores user in Flask g object.""" mock_user = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 1b6d03e36a..8d49394653 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -138,10 +138,10 @@ class TestGraphRuntimeState: _ = state.response_coordinator mock_graph = MagicMock() - with 
patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls: - coordinator_instance = MagicMock() - coordinator_cls.return_value = coordinator_instance - + with patch( + "core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True + ) as coordinator_cls: + coordinator_instance = coordinator_cls.return_value state.configure(graph=mock_graph) assert state.response_coordinator is coordinator_instance @@ -204,7 +204,7 @@ class TestGraphRuntimeState: mock_graph = MagicMock() stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub): + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True): state.attach_graph(mock_graph) stub.state = "configured" @@ -230,7 +230,7 @@ class TestGraphRuntimeState: assert restored_execution.started is True new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub): + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): restored.attach_graph(mock_graph) assert new_stub.state == "configured" @@ -251,14 +251,14 @@ class TestGraphRuntimeState: mock_graph = MagicMock() original_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub): + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True): state.attach_graph(mock_graph) original_stub.state = "configured" snapshot = state.dumps() new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub): + with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) restored.attach_graph(mock_graph) restored.loads(snapshot) 
diff --git a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py index be165bf1c1..3f47610312 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py +++ b/api/tests/unit_tests/core/workflow/entities/test_private_workflow_pause.py @@ -63,7 +63,7 @@ class TestPrivateWorkflowPauseEntity: assert entity.resumed_at is None - @patch("repositories.sqlalchemy_api_workflow_run_repository.storage") + @patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) def test_get_state_first_call(self, mock_storage): """Test get_state loads from storage on first call.""" state_data = b'{"test": "data", "step": 5}' @@ -81,7 +81,7 @@ class TestPrivateWorkflowPauseEntity: mock_storage.load.assert_called_once_with("test-state-key") assert entity._cached_state == state_data - @patch("repositories.sqlalchemy_api_workflow_run_repository.storage") + @patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) def test_get_state_cached_call(self, mock_storage): """Test get_state returns cached data on subsequent calls.""" state_data = b'{"test": "data", "step": 5}' @@ -102,7 +102,7 @@ class TestPrivateWorkflowPauseEntity: # Storage should only be called once mock_storage.load.assert_called_once_with("test-state-key") - @patch("repositories.sqlalchemy_api_workflow_run_repository.storage") + @patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) def test_get_state_with_pre_cached_data(self, mock_storage): """Test get_state returns pre-cached data.""" state_data = b'{"test": "data", "step": 5}' @@ -125,7 +125,7 @@ class TestPrivateWorkflowPauseEntity: # Test with binary data that's not valid JSON binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe" - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: + with 
patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) as mock_storage: mock_storage.load.return_value = binary_data mock_pause_model = MagicMock(spec=WorkflowPauseModel) diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 18f6753b05..d4254df319 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -1,10 +1,10 @@ -from core.variables.segments import ( +from core.workflow.runtime import VariablePool +from core.workflow.variables.segments import ( BooleanSegment, IntegerSegment, NoneSegment, StringSegment, ) -from core.workflow.runtime import VariablePool class TestVariablePoolGetAndNestedAttribute: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index f33fd0deeb..db9b977e4a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,7 +3,6 @@ import json from unittest.mock import MagicMock -from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.entities.commands import ( AbortCommand, @@ -12,6 +11,7 @@ from core.workflow.graph_engine.entities.commands import ( UpdateVariablesCommand, VariableUpdate, ) +from core.workflow.variables import IntegerVariable, StringVariable class TestRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 35a234be0b..903800ce88 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ 
b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -90,14 +90,14 @@ def mock_tool_node(): @pytest.fixture def mock_is_instrument_flag_enabled_false(): """Mock is_instrument_flag_enabled to return False.""" - with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False): + with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False, autospec=True): yield @pytest.fixture def mock_is_instrument_flag_enabled_true(): """Mock is_instrument_flag_enabled to return True.""" - with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True): + with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True, autospec=True): yield diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 1af5a80a56..6c3700ea2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,7 +4,6 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph @@ -20,6 +19,7 @@ from core.workflow.graph_engine.entities.commands import ( from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.variables import IntegerVariable, StringVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 186f8a8425..b862cbe89e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -112,7 +112,6 @@ class MockNodeFactory(DifyNodeFactory): graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, code_executor=self._code_executor, - code_providers=self._code_providers, code_limits=self._code_limits, ) elif node_type == NodeType.HTTP_REQUEST: @@ -123,6 +122,9 @@ class MockNodeFactory(DifyNodeFactory): graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, http_request_config=self._http_request_config, + http_client=self._http_request_http_client, + tool_file_manager_factory=self._http_request_tool_file_manager_factory, + file_manager=self._http_request_file_manager, ) elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: mock_instance = mock_class( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 71e8a9d863..5aed463a45 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -44,9 +45,10 @@ class MockNodeMixin: mock_config: Optional["MockConfig"] = None, **kwargs: Any, ): - if isinstance(self, (LLMNode, QuestionClassifierNode)): + if isinstance(self, (LLMNode, 
QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) + kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) super().__init__( id=id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index e760d7b3d3..6c4178dfed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -215,9 +215,9 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that MockTemplateTransformNode processes templates with variables.""" - from core.variables import StringVariable from core.workflow.entities import GraphInitParams from core.workflow.runtime import GraphRuntimeState, VariablePool + from core.workflow.variables import StringVariable # Create test parameters graph_init_params = GraphInitParams( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index 53c6bc3d60..a93d03c87e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -9,11 +9,12 @@ This test validates that: """ import time -from unittest.mock import patch +from unittest.mock import MagicMock, patch from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory +from core.model_manager import ModelInstance from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from 
core.workflow.graph import Graph @@ -115,7 +116,10 @@ def test_parallel_streaming_workflow(): # Create node factory and graph node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + with patch.object( + DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True + ): + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) # Create the graph engine engine = GraphEngine( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py index 0b998034b1..6d2ce4cb71 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_stop_event.py @@ -378,7 +378,7 @@ class TestStopEventIntegration: class TestStopEventTimeoutBehavior: """Test stop_event behavior with join timeouts.""" - @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread") + @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True) def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock): """Test that Dispatcher uses 2s timeout instead of 10s.""" runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) @@ -405,7 +405,7 @@ class TestStopEventTimeoutBehavior: mock_thread_instance.join.assert_called_once_with(timeout=2.0) - @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker") + @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True) def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock): """Test that WorkerPool uses 2s timeout instead of 10s.""" runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index afa9265fcd..5cbb7cf36e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -21,15 +21,6 @@ from typing import Any from core.app.workflow.node_factory import DifyNodeFactory from core.tools.utils.yaml_utils import _load_yaml_file -from core.variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - ObjectVariable, - StringVariable, -) from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine, GraphEngineConfig @@ -41,6 +32,15 @@ from core.workflow.graph_events import ( ) from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + StringVariable, +) from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory @@ -547,8 +547,22 @@ class TableTestRunner: """Run tests in parallel.""" results = [] + flask_app: Any = None + try: + from flask import current_app + + flask_app = current_app._get_current_object() # type: ignore[attr-defined] + except RuntimeError: + flask_app = None + + def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult: + if flask_app is None: + return self.run_test_case(test_case) + with flask_app.app_context(): + return self.run_test_case(test_case) + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases} + future_to_test = 
{executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases} for future in as_completed(future_to_test): test_case = future_to_test[future] diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 2262d25a14..00c8cb3779 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,14 +1,13 @@ from configs import dify_config -from core.helper.code_executor.code_executor import CodeLanguage -from core.variables.types import SegmentType from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData from core.workflow.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.variables.types import SegmentType CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, @@ -438,7 +437,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node.init_node_data(data) + node._node_data = node._hydrate_node_data(data) assert node._node_data.title == "Test Node" assert node._node_data.code_language == CodeLanguage.PYTHON3 @@ -454,7 +453,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node.init_node_data(data) + node._node_data = node._hydrate_node_data(data) assert node._node_data.code_language == CodeLanguage.JAVASCRIPT diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index d14a6ea69c..28d59c3568 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,9 +1,8 @@ import pytest from 
pydantic import ValidationError -from core.helper.code_executor.code_executor import CodeLanguage -from core.variables.types import SegmentType -from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData +from core.workflow.variables.types import SegmentType class TestCodeNodeDataOutput: diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py new file mode 100644 index 0000000000..584ed23e91 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -0,0 +1,93 @@ +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from core.workflow.nodes.datasource.datasource_node import DatasourceNode + + +class _VarSeg: + def __init__(self, v): + self.value = v + + +class _VarPool: + def __init__(self, mapping): + self._m = mapping + + def get(self, selector): + d = self._m + for k in selector: + d = d[k] + return _VarSeg(d) + + def add(self, *_args, **_kwargs): + pass + + +class _GraphState: + def __init__(self, var_pool): + self.variable_pool = var_pool + + +class _GraphParams: + tenant_id = "t1" + app_id = "app-1" + workflow_id = "wf-1" + graph_config = {} + user_id = "u1" + user_from = "account" + invoke_from = "debugger" + call_depth = 0 + + +def test_datasource_node_delegates_to_manager_stream(mocker): + # prepare sys variables + sys_vars = { + "sys": { + "datasource_type": "online_document", + "datasource_info": { + "workspace_id": "w", + "page": {"page_id": "pg", "type": "t"}, + "credential_id": "", + }, + } + } + var_pool = _VarPool(sys_vars) + gs = _GraphState(var_pool) + gp = _GraphParams() + + # stub manager class + class _Mgr: + @classmethod + def get_icon_url(cls, **_): + return "icon" + + @classmethod + def 
stream_node_events(cls, **_): + yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False) + yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)) + + @classmethod + def get_upload_file_by_id(cls, **_): + raise AssertionError("not called") + + node = DatasourceNode( + id="n", + config={ + "id": "n", + "data": { + "type": "datasource", + "version": "1", + "title": "Datasource", + "provider_type": "plugin", + "provider_name": "p", + "plugin_id": "plug", + "datasource_name": "ds", + }, + }, + graph_init_params=gp, + graph_runtime_state=gs, + datasource_manager=_Mgr, + ) + + evts = list(node._run()) + assert isinstance(evts[0], StreamChunkEvent) + assert isinstance(evts[-1], StreamCompletedEvent) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 65f4de8c1d..67da890eb2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,6 +1,8 @@ import pytest from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.file.file_manager import file_manager from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -59,6 +61,8 @@ def test_executor_with_json_body_and_number_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -113,6 +117,8 @@ def test_executor_with_json_body_and_object_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the 
executor's data @@ -169,6 +175,8 @@ def test_executor_with_json_body_and_nested_object_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -213,6 +221,8 @@ def test_extract_selectors_from_template_with_newline(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) assert executor.params == [("test", "line1\nline2")] @@ -258,6 +268,8 @@ def test_executor_with_form_data(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -309,6 +321,8 @@ def test_init_headers(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=VariablePool(system_variables=SystemVariable.default()), + http_client=ssrf_proxy, + file_manager=file_manager, ) executor = create_executor("aa\n cc:") @@ -344,6 +358,8 @@ def test_init_params(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=VariablePool(system_variables=SystemVariable.default()), + http_client=ssrf_proxy, + file_manager=file_manager, ) # Test basic key-value pairs @@ -394,6 +410,8 @@ def test_empty_api_key_raises_error_bearer(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -419,6 +437,8 @@ def test_empty_api_key_raises_error_basic(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -444,6 +464,8 @@ def test_empty_api_key_raises_error_custom(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, 
variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -469,6 +491,8 @@ def test_whitespace_only_api_key_raises_error(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -493,6 +517,8 @@ def test_valid_api_key_works(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Should not raise an error @@ -541,6 +567,8 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # The UUID should be preserved in full, not truncated @@ -586,6 +614,8 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # The UUID should be preserved in full @@ -625,6 +655,8 @@ def test_executor_with_json_body_preserves_numbers_and_strings(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) assert executor.json["count"] == 42 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 472718188f..cad0466809 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -5,8 +5,11 @@ import httpx import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.ssrf_proxy 
import ssrf_proxy +from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.file.file_manager import file_manager from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout, Response from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -116,6 +119,9 @@ def _build_http_node( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 5733b2cf5b..a60dde199d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -6,7 +6,6 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities.llm_entities import LLMUsage -from core.variables import StringSegment from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -20,6 +19,7 @@ from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import Kno from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import StringSegment from models.enums import UserFrom diff --git 
a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 366bec5001..63a87623da 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -5,9 +5,9 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.variables import ArrayNumberSegment, ArrayStringSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.nodes.list_operator.node import ListOperatorNode +from core.workflow.variables import ArrayNumberSegment, ArrayStringSegment from models.workflow import WorkflowType diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index ebabf66b41..94b5b72ee1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -9,8 +9,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.model_manager import ModelInstance from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, @@ -19,7 +21,7 @@ from core.model_runtime.entities.message_entities import ( ) from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from 
core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.workflow.entities import GraphInitParams from core.workflow.file import File, FileTransferMethod, FileType from core.workflow.nodes.llm import llm_utils @@ -32,10 +34,11 @@ from core.workflow.nodes.llm.entities import ( VisionConfigOptions, ) from core.workflow.nodes.llm.file_saver import LLMFileSaver -from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.enums import UserFrom from models.provider import ProviderType @@ -115,6 +118,7 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, + model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, ) return node @@ -194,11 +198,10 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE provider_model_bundle.configuration.__class__, "get_provider_model", return_value=provider_model, + autospec=True, ), mock.patch.object( - model_type_instance.__class__, - "get_model_schema", - return_value=model_config.model_schema, + model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): fetch_model_config( @@ -585,6 +588,41 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] +def test_handle_memory_completion_mode_uses_prompt_message_interface(): + memory = 
mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="first question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="first answer"), + ] + + model_instance = mock.MagicMock(spec=ModelInstance) + + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=3), + ) + + with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + + assert memory_text == "Human: first question\n[image]\nAssistant: first answer" + mock_rest_token.assert_called_once_with(prompt_messages=[], model_instance=model_instance) + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=3) + + @pytest.fixture def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]: mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) @@ -601,6 +639,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, + model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, ) return node, mock_file_saver diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py index b28d1d3d0a..2742b7dab0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -1,5 +1,5 @@ -from core.variables.types import SegmentType from core.workflow.nodes.parameter_extractor.entities import ParameterConfig +from core.workflow.variables.types import SegmentType class TestParameterConfig: diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index b359284d00..ae229bbe2e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -8,7 +8,6 @@ from typing import Any import pytest from core.model_runtime.entities import LLMMode -from core.variables.types import SegmentType from core.workflow.nodes.llm import ModelConfig, VisionConfig from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData from core.workflow.nodes.parameter_extractor.exc import ( @@ -18,6 +17,7 @@ from core.workflow.nodes.parameter_extractor.exc import ( RequiredParameterMissingError, ) from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.variables.types import SegmentType from factories.variable_factory import build_segment_with_type diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 61bdcbd250..0fb76fb7e7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -128,7 +128,8 @@ class TestTemplateTransformNode: assert TemplateTransformNode.version() == "1" @patch( - 
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_simple_template( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params @@ -165,7 +166,8 @@ class TestTemplateTransformNode: assert result.inputs["age"] == 30 @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" @@ -192,7 +194,8 @@ class TestTemplateTransformNode: assert result.inputs["value"] is None @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_code_execution_error( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params @@ -215,7 +218,8 @@ class TestTemplateTransformNode: assert "Template syntax error" in result.error @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_output_length_exceeds_limit( self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params @@ -239,7 +243,8 @@ class TestTemplateTransformNode: assert "Output length exceeds" in result.error @patch( - 
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_complex_jinja2_template( self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params @@ -303,7 +308,8 @@ class TestTemplateTransformNode: assert mapping["node_123.var2"] == ["sys", "input2"] @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" @@ -330,7 +336,8 @@ class TestTemplateTransformNode: assert result.inputs == {} @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" @@ -369,7 +376,8 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "Total: $31.5" @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" @@ -400,7 +408,8 @@ class TestTemplateTransformNode: assert 
"john@example.com" in result.outputs["output"] @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template" + "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", + autospec=True, ) def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 669f36c100..35c59b92c4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -6,9 +6,6 @@ import pytest from docx.oxml.text.paragraph import CT_P from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import ArrayFileSegment -from core.variables.segments import ArrayStringSegment -from core.variables.variables import StringVariable from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod @@ -20,6 +17,9 @@ from core.workflow.nodes.document_extractor.node import ( _extract_text_from_pdf, _extract_text_from_plain_text, ) +from core.workflow.variables import ArrayFileSegment +from core.workflow.variables.segments import ArrayStringSegment +from core.workflow.variables.variables import StringVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 930bdbda4a..bc87a64161 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -6,7 +6,6 @@ import pytest from core.app.entities.app_invoke_entities import 
InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory -from core.variables import ArrayFileSegment from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod, FileType @@ -16,6 +15,7 @@ from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from core.workflow.variables import ArrayFileSegment from extensions.ext_database import db from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 66ddc0d3c7..73c17ee45a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import ArrayFileSegment from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod, FileType from core.workflow.nodes.list_operator.entities import ( @@ -17,6 +16,7 @@ from core.workflow.nodes.list_operator.entities import ( ) from core.workflow.nodes.list_operator.exc import InvalidKeyError from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from core.workflow.variables import ArrayFileSegment from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 526ff72c8c..678691439f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -11,12 +11,12 @@ import pytest from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.variables.segments import ArrayFileSegment from core.workflow.entities import GraphInitParams from core.workflow.file import File, FileTransferMethod, FileType from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables.segments import ArrayFileSegment if TYPE_CHECKING: # pragma: no cover - imported for type checking only from core.workflow.nodes.tool.tool_node import ToolNode @@ -92,7 +92,9 @@ def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[lis return messages tool_runtime = MagicMock() - with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform): + with patch.object( + ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True + ): generator = tool_node._transform_message( messages=iter([message]), tool_info={"provider_type": "builtin", "provider_id": "provider"}, diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index d4b7a017f9..8a52f963ef 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -4,7 +4,6 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory -from core.variables 
import ArrayStringVariable, StringVariable from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_events.node import NodeRunSucceededEvent @@ -13,6 +12,7 @@ from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import ArrayStringVariable, StringVariable from models.enums import UserFrom DEFAULT_NODE_ID = "node_id" diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index 1501722b82..9a874337ed 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,6 +1,6 @@ -from core.variables import SegmentType from core.workflow.nodes.variable_assigner.v2.enums import Operation from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid +from core.workflow.variables import SegmentType def test_is_input_value_valid_overwrite_array_string(): diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index b08f9c37b4..5ed68fe8d0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -4,13 +4,13 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory -from core.variables import ArrayStringVariable from core.workflow.entities import GraphInitParams from 
core.workflow.graph import Graph from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import ArrayStringVariable from models.enums import UserFrom DEFAULT_NODE_ID = "node_id" diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 8ceaad5cc9..24d3740b99 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -3,7 +3,6 @@ from unittest.mock import patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from core.variables import FileVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod, FileType @@ -18,6 +17,7 @@ from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.runtime.graph_runtime_state import GraphRuntimeState from core.workflow.runtime.variable_pool import VariablePool from core.workflow.system_variable import SystemVariable +from core.workflow.variables import FileVariable, StringVariable from models.enums import UserFrom from models.workflow import WorkflowType diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index fb9a893d43..7f2b080498 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -3,8 +3,12 @@ from collections import defaultdict import pytest -from core.variables import FileSegment, 
StringSegment -from core.variables.segments import ( +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.file import File, FileTransferMethod, FileType +from core.workflow.runtime import VariablePool +from core.workflow.system_variable import SystemVariable +from core.workflow.variables import FileSegment, StringSegment +from core.workflow.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -15,7 +19,7 @@ from core.variables.segments import ( NoneSegment, ObjectSegment, ) -from core.variables.variables import ( +from core.workflow.variables.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -25,10 +29,6 @@ from core.variables.variables import ( StringVariable, Variable, ) -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable from factories.variable_factory import build_segment, segment_to_variable diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 793b0d4eba..4a71692f1e 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -4,7 +4,6 @@ import pytest from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage -from core.variables.variables import StringVariable from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, @@ -15,6 +14,7 @@ from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.runtime import VariablePool from core.workflow.system_variable 
import SystemVariable +from core.workflow.variables.variables import StringVariable from core.workflow.workflow_entry import WorkflowEntry diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index bc55d3fccf..12b9bf5f14 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -26,11 +26,8 @@ class TestWorkflowEntryRedisChannel: redis_channel = RedisChannel(mock_redis_client, "test:channel:key") # Patch GraphEngine to verify it receives the Redis channel - with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine: - mock_graph_engine = MagicMock() - MockGraphEngine.return_value = mock_graph_engine - - # Create WorkflowEntry with Redis channel + with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine: + mock_graph_engine = MockGraphEngine.return_value # Create WorkflowEntry with Redis channel workflow_entry = WorkflowEntry( tenant_id="test-tenant", app_id="test-app", @@ -63,15 +60,11 @@ class TestWorkflowEntryRedisChannel: # Patch GraphEngine and InMemoryChannel with ( - patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine, - patch("core.workflow.workflow_entry.InMemoryChannel") as MockInMemoryChannel, + patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine, + patch("core.workflow.workflow_entry.InMemoryChannel", autospec=True) as MockInMemoryChannel, ): - mock_graph_engine = MagicMock() - MockGraphEngine.return_value = mock_graph_engine - mock_inmemory_channel = MagicMock() - MockInMemoryChannel.return_value = mock_inmemory_channel - - # Create WorkflowEntry without providing a channel + mock_graph_engine = MockGraphEngine.return_value + mock_inmemory_channel = MockInMemoryChannel.return_value # Create WorkflowEntry without providing a channel workflow_entry = 
WorkflowEntry( tenant_id="test-tenant", app_id="test-app", @@ -114,7 +107,7 @@ class TestWorkflowEntryRedisChannel: mock_event2 = MagicMock() # Patch GraphEngine - with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine: + with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine: mock_graph_engine = MagicMock() mock_graph_engine.run.return_value = iter([mock_event1, mock_event2]) MockGraphEngine.return_value = mock_graph_engine diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index 77c4956c04..601f2c5e3a 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -40,7 +40,7 @@ def mock_upload_file(): mock.source_url = TEST_REMOTE_URL mock.size = 1024 mock.key = "test_key" - with patch("factories.file_factory.db.session.scalar", return_value=mock) as m: + with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True) as m: yield m @@ -54,7 +54,7 @@ def mock_tool_file(): mock.mimetype = "application/pdf" mock.original_url = "http://example.com/tool.pdf" mock.size = 2048 - with patch("factories.file_factory.db.session.scalar", return_value=mock): + with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True): yield mock @@ -70,7 +70,7 @@ def mock_http_head(): }, ) - with patch("factories.file_factory.ssrf_proxy.head") as mock_head: + with patch("factories.file_factory.ssrf_proxy.head", autospec=True) as mock_head: mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg") yield mock_head @@ -188,7 +188,7 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head): def test_tool_file_not_found(): """Test ToolFile not found in database.""" - with patch("factories.file_factory.db.session.scalar", return_value=None): + with patch("factories.file_factory.db.session.scalar", 
return_value=None, autospec=True): mapping = tool_file_mapping() with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -196,7 +196,7 @@ def test_tool_file_not_found(): def test_local_file_not_found(): """Test UploadFile not found in database.""" - with patch("factories.file_factory.db.session.scalar", return_value=None): + with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True): mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) @@ -268,7 +268,7 @@ def test_tenant_mismatch(): mock_file.key = "test_key" # Mock the database query to return None (no file found for this tenant) - with patch("factories.file_factory.db.session.scalar", return_value=None): + with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True): mapping = local_file_mapping() with pytest.raises(ValueError, match="Invalid upload file"): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 53ae18a61d..87d02cb187 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -7,7 +7,8 @@ import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.variables import ( +from core.workflow.file import File, FileTransferMethod, FileType +from core.workflow.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -16,8 +17,8 @@ from core.variables import ( SecretVariable, StringVariable, ) -from core.variables.exc import VariableError -from core.variables.segments import ( +from core.workflow.variables.exc import VariableError +from core.workflow.variables.segments 
import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -32,8 +33,7 @@ from core.variables.segments import ( Segment, StringSegment, ) -from core.variables.types import SegmentType -from core.workflow.file import File, FileTransferMethod, FileType +from core.workflow.variables.types import SegmentType from factories import variable_factory from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py index f84df42bfd..460374b6f6 100644 --- a/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_channel_unit_tests.py @@ -403,7 +403,7 @@ class TestRedisSubscription: # ==================== Listener Thread Tests ==================== - @patch("time.sleep", side_effect=lambda x: None) # Speed up test + @patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test def test_listener_thread_normal_operation( self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock ): @@ -826,7 +826,7 @@ class TestRedisShardedSubscription: # ==================== Listener Thread Tests ==================== - @patch("time.sleep", side_effect=lambda x: None) # Speed up test + @patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test def test_listener_thread_normal_operation( self, mock_sleep, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock ): diff --git a/api/tests/unit_tests/libs/test_datetime_utils.py b/api/tests/unit_tests/libs/test_datetime_utils.py index 84f5b63fbf..57314d29d4 100644 --- a/api/tests/unit_tests/libs/test_datetime_utils.py +++ b/api/tests/unit_tests/libs/test_datetime_utils.py @@ -104,7 +104,7 @@ class TestParseTimeRange: def test_parse_time_range_dst_ambiguous_time(self): """Test parsing during DST 
ambiguous time (fall back).""" # This test simulates DST fall back where 2:30 AM occurs twice - with patch("pytz.timezone") as mock_timezone: + with patch("pytz.timezone", autospec=True) as mock_timezone: # Mock timezone that raises AmbiguousTimeError mock_tz = mock_timezone.return_value @@ -135,7 +135,7 @@ class TestParseTimeRange: def test_parse_time_range_dst_nonexistent_time(self): """Test parsing during DST nonexistent time (spring forward).""" - with patch("pytz.timezone") as mock_timezone: + with patch("pytz.timezone", autospec=True) as mock_timezone: # Mock timezone that raises NonExistentTimeError mock_tz = mock_timezone.return_value diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 35155b4931..df80428ee8 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -55,7 +55,7 @@ class TestLoginRequired: with setup_app.test_request_context(): # Mock authenticated user mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): result = protected_view() assert result == "Protected content" @@ -70,7 +70,7 @@ class TestLoginRequired: with setup_app.test_request_context(): # Mock unauthenticated user mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): result = protected_view() assert result == "Unauthorized" setup_app.login_manager.unauthorized.assert_called_once() @@ -86,8 +86,8 @@ class TestLoginRequired: with setup_app.test_request_context(): # Mock unauthenticated user and LOGIN_DISABLED mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user): - with patch("libs.login.dify_config") as mock_config: + with 
patch("libs.login._get_user", return_value=mock_user, autospec=True): + with patch("libs.login.dify_config", autospec=True) as mock_config: mock_config.LOGIN_DISABLED = True result = protected_view() @@ -106,7 +106,7 @@ class TestLoginRequired: with setup_app.test_request_context(method="OPTIONS"): # Mock unauthenticated user mock_user = MockUser("test_user", is_authenticated=False) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): result = protected_view() assert result == "Protected content" # Ensure unauthorized was not called @@ -125,7 +125,7 @@ class TestLoginRequired: with setup_app.test_request_context(): mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): result = protected_view() assert result == "Synced content" setup_app.ensure_sync.assert_called_once() @@ -144,7 +144,7 @@ class TestLoginRequired: with setup_app.test_request_context(): mock_user = MockUser("test_user", is_authenticated=True) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): result = protected_view() assert result == "Protected content" @@ -197,14 +197,14 @@ class TestCurrentUser: mock_user = MockUser("test_user", is_authenticated=True) with app.test_request_context(): - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): assert current_user.id == "test_user" assert current_user.is_authenticated is True def test_current_user_proxy_returns_none_when_no_user(self, app: Flask): """Test that current_user proxy handles None user.""" with app.test_request_context(): - with patch("libs.login._get_user", return_value=None): + with patch("libs.login._get_user", return_value=None, autospec=True): # 
When _get_user returns None, accessing attributes should fail # or current_user should evaluate to falsy try: @@ -224,7 +224,7 @@ class TestCurrentUser: def check_user_in_thread(user_id: str, index: int): with app.test_request_context(): mock_user = MockUser(user_id) - with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login._get_user", return_value=mock_user, autospec=True): results[index] = current_user.id # Create multiple threads with different users diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index b6595a8c57..bc7880ccc8 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post") + @patch("httpx.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, mock_response, response_data, expected_token, should_raise ): @@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest): ), ], ) - @patch("httpx.get") + @patch("httpx.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): user_response = MagicMock() user_response.json.return_value = user_data @@ -121,7 +121,7 @@ class TestGitHubOAuth(BaseOAuthTest): assert user_info.name == user_data["name"] assert user_info.email == expected_email - @patch("httpx.get") + @patch("httpx.get", autospec=True) def test_should_handle_network_errors(self, mock_get, oauth): mock_get.side_effect = httpx.RequestError("Network error") @@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({}, None, True), ], ) - @patch("httpx.post") + @patch("httpx.post", autospec=True) def test_should_retrieve_access_token( self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise ): @@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest): ({"sub": "123", "email": 
"test@example.com", "name": "Test User"}, ""), # Always returns empty string ], ) - @patch("httpx.get") + @patch("httpx.get", autospec=True) def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): mock_response.json.return_value = user_data mock_get.return_value = mock_response @@ -222,7 +222,7 @@ class TestGoogleOAuth(BaseOAuthTest): httpx.TimeoutException, ], ) - @patch("httpx.get") + @patch("httpx.get", autospec=True) def test_should_handle_http_errors(self, mock_get, oauth, exception_type): mock_response = MagicMock() mock_response.raise_for_status.side_effect = exception_type("Error") diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py index 042bc15643..1edf4899ac 100644 --- a/api/tests/unit_tests/libs/test_smtp_client.py +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -9,11 +9,9 @@ def _mail() -> dict: return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} -@patch("libs.smtp.smtplib.SMTP") +@patch("libs.smtp.smtplib.SMTP", autospec=True) def test_smtp_plain_success(mock_smtp_cls: MagicMock): - mock_smtp = MagicMock() - mock_smtp_cls.return_value = mock_smtp - + mock_smtp = mock_smtp_cls.return_value client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") client.send(_mail()) @@ -22,11 +20,9 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock): mock_smtp.quit.assert_called_once() -@patch("libs.smtp.smtplib.SMTP") +@patch("libs.smtp.smtplib.SMTP", autospec=True) def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): - mock_smtp = MagicMock() - mock_smtp_cls.return_value = mock_smtp - + mock_smtp = mock_smtp_cls.return_value client = SMTPClient( server="smtp.example.com", port=587, @@ -46,7 +42,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): mock_smtp.quit.assert_called_once() -@patch("libs.smtp.smtplib.SMTP_SSL") 
+@patch("libs.smtp.smtplib.SMTP_SSL", autospec=True) def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): # Cover SMTP_SSL branch and TimeoutError handling mock_smtp = MagicMock() @@ -67,7 +63,7 @@ def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): mock_smtp.quit.assert_called_once() -@patch("libs.smtp.smtplib.SMTP") +@patch("libs.smtp.smtplib.SMTP", autospec=True) def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): mock_smtp = MagicMock() mock_smtp.sendmail.side_effect = RuntimeError("oops") @@ -79,7 +75,7 @@ def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): mock_smtp.quit.assert_called_once() -@patch("libs.smtp.smtplib.SMTP") +@patch("libs.smtp.smtplib.SMTP", autospec=True) def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock): # Ensure we hit the specific SMTPException except branch import smtplib diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index c6dfd41803..8b96c62dc9 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -301,7 +301,7 @@ class TestAppModelConfig: ) # Mock database query to return None - with patch("models.model.db.session.query") as mock_query: + with patch("models.model.db.session.query", autospec=True) as mock_query: mock_query.return_value.where.return_value.first.return_value = None # Act @@ -952,7 +952,7 @@ class TestSiteModel: def test_site_generate_code(self): """Test Site.generate_code static method.""" # Mock database query to return 0 (no existing codes) - with patch("models.model.db.session.query") as mock_query: + with patch("models.model.db.session.query", autospec=True) as mock_query: mock_query.return_value.where.return_value.count.return_value = 0 # Act @@ -1167,7 +1167,7 @@ class TestConversationStatusCount: conversation.id = str(uuid4()) # Mock the database query to return no messages - with 
patch("models.model.db.session.scalars") as mock_scalars: + with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1192,7 +1192,7 @@ class TestConversationStatusCount: conversation.id = conversation_id # Mock the database query to return no messages with workflow_run_id - with patch("models.model.db.session.scalars") as mock_scalars: + with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: mock_scalars.return_value.all.return_value = [] # Act @@ -1277,7 +1277,7 @@ class TestConversationStatusCount: return mock_result # Act & Assert - with patch("models.model.db.session.scalars", side_effect=mock_scalars): + with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): result = conversation.status_count # Verify only 2 database queries were made (not N+1) @@ -1340,7 +1340,7 @@ class TestConversationStatusCount: return mock_result # Act - with patch("models.model.db.session.scalars", side_effect=mock_scalars): + with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True): result = conversation.status_count # Assert - query should include app_id filter @@ -1385,7 +1385,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars") as mock_scalars: + with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: # Mock the messages query def mock_scalars_side_effect(query): mock_result = MagicMock() @@ -1441,7 +1441,7 @@ class TestConversationStatusCount: ), ] - with patch("models.model.db.session.scalars") as mock_scalars: + with patch("models.model.db.session.scalars", autospec=True) as mock_scalars: def mock_scalars_side_effect(query): mock_result = MagicMock() diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index 5d84a2ec85..d44aa56488 100644 --- 
a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,6 +1,6 @@ from uuid import uuid4 -from core.variables import SegmentType +from core.workflow.variables import SegmentType from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 29f71767d0..544693da34 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -4,10 +4,10 @@ from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE -from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from core.variables.segments import IntegerSegment, Segment from core.workflow.file.enums import FileTransferMethod, FileType from core.workflow.file.models import File +from core.workflow.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from core.workflow.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index a0fed1aa14..d54116555e 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -15,7 +15,7 @@ class TestTencentCos(BaseStorageTest): @pytest.fixture(autouse=True) def setup_method(self, setup_tencent_cos_mock): """Executed before each test method.""" - with patch.object(CosConfig, "__init__", return_value=None): + with patch.object(CosConfig, "__init__", return_value=None, autospec=True): self.storage = TencentCosStorage() self.storage.bucket_name = get_example_bucket() @@ -39,9 +39,9 @@ class 
TestTencentCosConfiguration: with ( patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), patch( - "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True ) as mock_cos_config, - patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True), ): TencentCosStorage() @@ -72,9 +72,9 @@ class TestTencentCosConfiguration: with ( patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config), patch( - "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance + "extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True ) as mock_cos_config, - patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client), + patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True), ): TencentCosStorage() diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py index 9d9cb7c6d5..60af6e20c2 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py @@ -19,7 +19,7 @@ class TestApiKeyAuthFactory: ) def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path): """Test getting auth factory for all valid providers""" - with patch(auth_class_path) as mock_auth: + with patch(auth_class_path, autospec=True) as mock_auth: auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) assert auth_class == mock_auth @@ -46,7 +46,7 @@ class TestApiKeyAuthFactory: (False, False), ], ) - @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory") + 
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True) def test_validate_credentials_delegates_to_auth_instance( self, mock_get_factory, credentials_return_value, expected_result ): @@ -65,7 +65,7 @@ class TestApiKeyAuthFactory: assert result is expected_result mock_auth_instance.validate_credentials.assert_called_once() - @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory") + @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True) def test_validate_credentials_propagates_exceptions(self, mock_get_factory): """Test that exceptions from auth instance are propagated""" # Arrange diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py index ab50d6a92c..1458180570 100644 --- a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py +++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py @@ -65,7 +65,7 @@ class TestFirecrawlAuth: FirecrawlAuth(credentials) assert str(exc_info.value) == expected_error - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): """Test successful credential validation""" mock_response = MagicMock() @@ -96,7 +96,7 @@ class TestFirecrawlAuth: (500, "Internal server error"), ], ) - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): """Test handling of various HTTP error codes""" mock_response = MagicMock() @@ -118,7 +118,7 @@ class TestFirecrawlAuth: (401, "Not JSON", True, "Failed to authorize. Status code: 401. 
Error: Not JSON"), ], ) - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_handle_unexpected_errors( self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance ): @@ -145,7 +145,7 @@ class TestFirecrawlAuth: (httpx.ConnectTimeout, "Connection timeout"), ], ) - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): """Test handling of various network-related errors including timeouts""" mock_post.side_effect = exception_type(exception_message) @@ -167,7 +167,7 @@ class TestFirecrawlAuth: FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) assert "super_secret_key_12345" not in str(exc_info.value) - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_use_custom_base_url_in_validation(self, mock_post): """Test that custom base URL is used in validation and normalized""" mock_response = MagicMock() @@ -185,7 +185,7 @@ class TestFirecrawlAuth: assert result is True assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" - @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True) def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): """Test that timeout errors are handled gracefully with appropriate error message""" mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds") diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py index 4d2f300d25..67f252390d 100644 --- a/api/tests/unit_tests/services/auth/test_jina_auth.py +++ 
b/api/tests/unit_tests/services/auth/test_jina_auth.py @@ -35,7 +35,7 @@ class TestJinaAuth: JinaAuth(credentials) assert str(exc_info.value) == "No API key provided" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_validate_valid_credentials_successfully(self, mock_post): """Test successful credential validation""" mock_response = MagicMock() @@ -53,7 +53,7 @@ class TestJinaAuth: json={"url": "https://example.com"}, ) - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_http_402_error(self, mock_post): """Test handling of 402 Payment Required error""" mock_response = MagicMock() @@ -68,7 +68,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_http_409_error(self, mock_post): """Test handling of 409 Conflict error""" mock_response = MagicMock() @@ -83,7 +83,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_http_500_error(self, mock_post): """Test handling of 500 Internal Server Error""" mock_response = MagicMock() @@ -98,7 +98,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 500. 
Error: Internal server error" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_unexpected_error_with_text_response(self, mock_post): """Test handling of unexpected errors with text response""" mock_response = MagicMock() @@ -114,7 +114,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_unexpected_error_without_text(self, mock_post): """Test handling of unexpected errors without text response""" mock_response = MagicMock() @@ -130,7 +130,7 @@ class TestJinaAuth: auth.validate_credentials() assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" - @patch("services.auth.jina.jina.httpx.post") + @patch("services.auth.jina.jina.httpx.post", autospec=True) def test_should_handle_network_errors(self, mock_post): """Test handling of network connection errors""" mock_post.side_effect = httpx.ConnectError("Network error") diff --git a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py index ec99cb10b0..1d561731d4 100644 --- a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py +++ b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py @@ -64,7 +64,7 @@ class TestWatercrawlAuth: WatercrawlAuth(credentials) assert str(exc_info.value) == expected_error - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): """Test successful credential validation""" mock_response = MagicMock() @@ -87,7 +87,7 @@ class TestWatercrawlAuth: (500, "Internal server error"), ], ) - 
@patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): """Test handling of various HTTP error codes""" mock_response = MagicMock() @@ -107,7 +107,7 @@ class TestWatercrawlAuth: (401, "Not JSON", True, "Expecting value"), # JSON decode error ], ) - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_handle_unexpected_errors( self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance ): @@ -132,7 +132,7 @@ class TestWatercrawlAuth: (httpx.ConnectTimeout, "Connection timeout"), ], ) - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): """Test handling of various network-related errors including timeouts""" mock_get.side_effect = exception_type(exception_message) @@ -154,7 +154,7 @@ class TestWatercrawlAuth: WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) assert "super_secret_key_12345" not in str(exc_info.value) - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_use_custom_base_url_in_validation(self, mock_get): """Test that custom base URL is used in validation""" mock_response = MagicMock() @@ -179,7 +179,7 @@ class TestWatercrawlAuth: ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), ], ) - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): """Test 
that urljoin is used correctly for URL construction with various base URLs""" mock_response = MagicMock() @@ -193,7 +193,7 @@ class TestWatercrawlAuth: # Verify the correct URL was called assert mock_get.call_args[0][0] == expected_url - @patch("services.auth.watercrawl.watercrawl.httpx.get") + @patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True) def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): """Test that timeout errors are handled gracefully with appropriate error message""" mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds") diff --git a/api/tests/unit_tests/services/document_service_status.py b/api/tests/unit_tests/services/document_service_status.py index b83aba1171..1b682d5762 100644 --- a/api/tests/unit_tests/services/document_service_status.py +++ b/api/tests/unit_tests/services/document_service_status.py @@ -1,206 +1,16 @@ -""" -Comprehensive unit tests for DocumentService status management methods. +"""Unit tests for non-SQL validation in DocumentService status management methods.""" -This module contains extensive unit tests for the DocumentService class, -specifically focusing on document status management operations including -pause, recover, retry, batch updates, and renaming. - -The DocumentService provides methods for: -- Pausing document indexing processes (pause_document) -- Recovering documents from paused or error states (recover_document) -- Retrying failed document indexing operations (retry_document) -- Batch updating document statuses (batch_update_document_status) -- Renaming documents (rename_document) - -These operations are critical for document lifecycle management and require -careful handling of document states, indexing processes, and user permissions. 
- -This test suite ensures: -- Correct pause and resume of document indexing -- Proper recovery from error states -- Accurate retry mechanisms for failed operations -- Batch status updates work correctly -- Document renaming with proper validation -- State transitions are handled correctly -- Error conditions are handled gracefully - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The DocumentService status management operations are part of the document -lifecycle management system. These operations interact with multiple -components: - -1. Document States: Documents can be in various states: - - waiting: Waiting to be indexed - - parsing: Currently being parsed - - cleaning: Currently being cleaned - - splitting: Currently being split into segments - - indexing: Currently being indexed - - completed: Indexing completed successfully - - error: Indexing failed with an error - - paused: Indexing paused by user - -2. Status Flags: Documents have several status flags: - - is_paused: Whether indexing is paused - - enabled: Whether document is enabled for retrieval - - archived: Whether document is archived - - indexing_status: Current indexing status - -3. Redis Cache: Used for: - - Pause flags: Prevents concurrent pause operations - - Retry flags: Prevents concurrent retry operations - - Indexing flags: Tracks active indexing operations - -4. Task Queue: Async tasks for: - - Recovering document indexing - - Retrying document indexing - - Adding documents to index - - Removing documents from index - -5. 
Database: Stores document state and metadata: - - Document status fields - - Timestamps (paused_at, disabled_at, archived_at) - - User IDs (paused_by, disabled_by, archived_by) - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. Pause Operations: - - Pausing documents in various indexing states - - Setting pause flags in Redis - - Updating document state - - Error handling for invalid states - -2. Recovery Operations: - - Recovering paused documents - - Clearing pause flags - - Triggering recovery tasks - - Error handling for non-paused documents - -3. Retry Operations: - - Retrying failed documents - - Setting retry flags - - Resetting document status - - Preventing concurrent retries - - Triggering retry tasks - -4. Batch Status Updates: - - Enabling documents - - Disabling documents - - Archiving documents - - Unarchiving documents - - Handling empty lists - - Validating document states - - Transaction handling - -5. 
Rename Operations: - - Renaming documents successfully - - Validating permissions - - Updating metadata - - Updating associated files - - Error handling - -================================================================================ -""" - -import datetime -from unittest.mock import Mock, create_autospec, patch +from unittest.mock import Mock, create_autospec import pytest from models import Account -from models.dataset import Dataset, Document -from models.model import UploadFile +from models.dataset import Dataset from services.dataset_service import DocumentService -from services.errors.document import DocumentIndexingError - -# ============================================================================ -# Test Data Factory -# ============================================================================ class DocumentStatusTestDataFactory: - """ - Factory class for creating test data and mock objects for document status tests. - - This factory provides static methods to create mock objects for: - - Document instances with various status configurations - - Dataset instances - - User/Account instances - - UploadFile instances - - Redis cache keys and values - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_document_mock( - document_id: str = "document-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - name: str = "Test Document", - indexing_status: str = "completed", - is_paused: bool = False, - enabled: bool = True, - archived: bool = False, - paused_by: str | None = None, - paused_at: datetime.datetime | None = None, - data_source_type: str = "upload_file", - data_source_info: dict | None = None, - doc_metadata: dict | None = None, - **kwargs, - ) -> Mock: - """ - Create a mock Document with specified attributes. 
- - Args: - document_id: Unique identifier for the document - dataset_id: Dataset identifier - tenant_id: Tenant identifier - name: Document name - indexing_status: Current indexing status - is_paused: Whether document is paused - enabled: Whether document is enabled - archived: Whether document is archived - paused_by: ID of user who paused the document - paused_at: Timestamp when document was paused - data_source_type: Type of data source - data_source_info: Data source information dictionary - doc_metadata: Document metadata dictionary - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Document instance - """ - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.tenant_id = tenant_id - document.name = name - document.indexing_status = indexing_status - document.is_paused = is_paused - document.enabled = enabled - document.archived = archived - document.paused_by = paused_by - document.paused_at = paused_at - document.data_source_type = data_source_type - document.data_source_info = data_source_info or {} - document.doc_metadata = doc_metadata or {} - document.completed_at = datetime.datetime.now() if indexing_status == "completed" else None - document.position = 1 - for key, value in kwargs.items(): - setattr(document, key, value) - - # Mock data_source_info_dict property - document.data_source_info_dict = data_source_info or {} - - return document + """Factory class for creating test data and mock objects for document status tests.""" @staticmethod def create_dataset_mock( @@ -210,19 +20,7 @@ class DocumentStatusTestDataFactory: built_in_field_enabled: bool = False, **kwargs, ) -> Mock: - """ - Create a mock Dataset with specified attributes. 
- - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier - name: Dataset name - built_in_field_enabled: Whether built-in fields are enabled - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ + """Create a mock Dataset with specified attributes.""" dataset = Mock(spec=Dataset) dataset.id = dataset_id dataset.tenant_id = tenant_id @@ -238,17 +36,7 @@ class DocumentStatusTestDataFactory: tenant_id: str = "tenant-123", **kwargs, ) -> Mock: - """ - Create a mock user (Account) with specified attributes. - - Args: - user_id: Unique identifier for the user - tenant_id: Tenant identifier - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as an Account instance - """ + """Create a mock user (Account) with specified attributes.""" user = create_autospec(Account, instance=True) user.id = user_id user.current_tenant_id = tenant_id @@ -256,762 +44,11 @@ class DocumentStatusTestDataFactory: setattr(user, key, value) return user - @staticmethod - def create_upload_file_mock( - file_id: str = "file-123", - name: str = "test_file.pdf", - **kwargs, - ) -> Mock: - """ - Create a mock UploadFile with specified attributes. - - Args: - file_id: Unique identifier for the file - name: File name - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as an UploadFile instance - """ - upload_file = Mock(spec=UploadFile) - upload_file.id = file_id - upload_file.name = name - for key, value in kwargs.items(): - setattr(upload_file, key, value) - return upload_file - - -# ============================================================================ -# Tests for pause_document -# ============================================================================ - - -class TestDocumentServicePauseDocument: - """ - Comprehensive unit tests for DocumentService.pause_document method. 
- - This test class covers the document pause functionality, which allows - users to pause the indexing process for documents that are currently - being indexed. - - The pause_document method: - 1. Validates document is in a pausable state - 2. Sets is_paused flag to True - 3. Records paused_by and paused_at - 4. Commits changes to database - 5. Sets pause flag in Redis cache - - Test scenarios include: - - Pausing documents in various indexing states - - Error handling for invalid states - - Redis cache flag setting - - Current user validation - """ - - @pytest.fixture - def mock_document_service_dependencies(self): - """ - Mock document service dependencies for testing. - - Provides mocked dependencies including: - - current_user context - - Database session - - Redis client - - Current time utilities - """ - with ( - patch( - "services.dataset_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - mock_current_user.id = "user-123" - - yield { - "current_user": mock_current_user, - "db_session": mock_db, - "redis_client": mock_redis, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_pause_document_waiting_state_success(self, mock_document_service_dependencies): - """ - Test successful pause of document in waiting state. - - Verifies that when a document is in waiting state, it can be - paused successfully. 
- - This test ensures: - - Document state is validated - - is_paused flag is set - - paused_by and paused_at are recorded - - Changes are committed - - Redis cache flag is set - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="waiting", is_paused=False) - - # Act - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - assert document.paused_at == mock_document_service_dependencies["current_time"] - - # Verify database operations - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - - # Verify Redis cache flag was set - expected_cache_key = f"document_{document.id}_is_paused" - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") - - def test_pause_document_indexing_state_success(self, mock_document_service_dependencies): - """ - Test successful pause of document in indexing state. - - Verifies that when a document is actively being indexed, it can - be paused successfully. - - This test ensures: - - Document in indexing state can be paused - - All pause operations complete correctly - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="indexing", is_paused=False) - - # Act - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - - def test_pause_document_parsing_state_success(self, mock_document_service_dependencies): - """ - Test successful pause of document in parsing state. - - Verifies that when a document is being parsed, it can be paused. 
- - This test ensures: - - Document in parsing state can be paused - - Pause operations work for all valid states - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="parsing", is_paused=False) - - # Act - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - - def test_pause_document_completed_state_error(self, mock_document_service_dependencies): - """ - Test error when trying to pause completed document. - - Verifies that when a document is already completed, it cannot - be paused and a DocumentIndexingError is raised. - - This test ensures: - - Completed documents cannot be paused - - Error type is correct - - No database operations are performed - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="completed", is_paused=False) - - # Act & Assert - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - # Verify no database operations were performed - mock_document_service_dependencies["db_session"].add.assert_not_called() - mock_document_service_dependencies["db_session"].commit.assert_not_called() - - def test_pause_document_error_state_error(self, mock_document_service_dependencies): - """ - Test error when trying to pause document in error state. - - Verifies that when a document is in error state, it cannot be - paused and a DocumentIndexingError is raised. 
- - This test ensures: - - Error state documents cannot be paused - - Error type is correct - - No database operations are performed - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="error", is_paused=False) - - # Act & Assert - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - -# ============================================================================ -# Tests for recover_document -# ============================================================================ - - -class TestDocumentServiceRecoverDocument: - """ - Comprehensive unit tests for DocumentService.recover_document method. - - This test class covers the document recovery functionality, which allows - users to resume indexing for documents that were previously paused. - - The recover_document method: - 1. Validates document is paused - 2. Clears is_paused flag - 3. Clears paused_by and paused_at - 4. Commits changes to database - 5. Deletes pause flag from Redis cache - 6. Triggers recovery task - - Test scenarios include: - - Recovering paused documents - - Error handling for non-paused documents - - Redis cache flag deletion - - Recovery task triggering - """ - - @pytest.fixture - def mock_document_service_dependencies(self): - """ - Mock document service dependencies for testing. - - Provides mocked dependencies including: - - Database session - - Redis client - - Recovery task - """ - with ( - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.recover_document_indexing_task") as mock_task, - ): - yield { - "db_session": mock_db, - "redis_client": mock_redis, - "recover_task": mock_task, - } - - def test_recover_document_paused_success(self, mock_document_service_dependencies): - """ - Test successful recovery of paused document. 
- - Verifies that when a document is paused, it can be recovered - successfully and indexing resumes. - - This test ensures: - - Document is validated as paused - - is_paused flag is cleared - - paused_by and paused_at are cleared - - Changes are committed - - Redis cache flag is deleted - - Recovery task is triggered - """ - # Arrange - paused_time = datetime.datetime.now() - document = DocumentStatusTestDataFactory.create_document_mock( - indexing_status="indexing", - is_paused=True, - paused_by="user-123", - paused_at=paused_time, - ) - - # Act - DocumentService.recover_document(document) - - # Assert - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - - # Verify database operations - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - - # Verify Redis cache flag was deleted - expected_cache_key = f"document_{document.id}_is_paused" - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with(expected_cache_key) - - # Verify recovery task was triggered - mock_document_service_dependencies["recover_task"].delay.assert_called_once_with( - document.dataset_id, document.id - ) - - def test_recover_document_not_paused_error(self, mock_document_service_dependencies): - """ - Test error when trying to recover non-paused document. - - Verifies that when a document is not paused, it cannot be - recovered and a DocumentIndexingError is raised. 
- - This test ensures: - - Non-paused documents cannot be recovered - - Error type is correct - - No database operations are performed - """ - # Arrange - document = DocumentStatusTestDataFactory.create_document_mock(indexing_status="indexing", is_paused=False) - - # Act & Assert - with pytest.raises(DocumentIndexingError): - DocumentService.recover_document(document) - - # Verify no database operations were performed - mock_document_service_dependencies["db_session"].add.assert_not_called() - mock_document_service_dependencies["db_session"].commit.assert_not_called() - - -# ============================================================================ -# Tests for retry_document -# ============================================================================ - - -class TestDocumentServiceRetryDocument: - """ - Comprehensive unit tests for DocumentService.retry_document method. - - This test class covers the document retry functionality, which allows - users to retry failed document indexing operations. - - The retry_document method: - 1. Validates documents are not already being retried - 2. Sets retry flag in Redis cache - 3. Resets document indexing_status to waiting - 4. Commits changes to database - 5. Triggers retry task - - Test scenarios include: - - Retrying single document - - Retrying multiple documents - - Error handling for concurrent retries - - Current user validation - - Retry task triggering - """ - - @pytest.fixture - def mock_document_service_dependencies(self): - """ - Mock document service dependencies for testing. 
- - Provides mocked dependencies including: - - current_user context - - Database session - - Redis client - - Retry task - """ - with ( - patch( - "services.dataset_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.retry_document_indexing_task") as mock_task, - ): - mock_current_user.id = "user-123" - - yield { - "current_user": mock_current_user, - "db_session": mock_db, - "redis_client": mock_redis, - "retry_task": mock_task, - } - - def test_retry_document_single_success(self, mock_document_service_dependencies): - """ - Test successful retry of single document. - - Verifies that when a document is retried, the retry process - completes successfully. - - This test ensures: - - Retry flag is checked - - Document status is reset to waiting - - Changes are committed - - Retry flag is set in Redis - - Retry task is triggered - """ - # Arrange - dataset_id = "dataset-123" - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", - dataset_id=dataset_id, - indexing_status="error", - ) - - # Mock Redis to return None (not retrying) - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - DocumentService.retry_document(dataset_id, [document]) - - # Assert - assert document.indexing_status == "waiting" - - # Verify database operations - mock_document_service_dependencies["db_session"].add.assert_called_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called() - - # Verify retry flag was set - expected_cache_key = f"document_{document.id}_is_retried" - mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1) - - # Verify retry task was triggered - mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( - 
dataset_id, [document.id], "user-123" - ) - - def test_retry_document_multiple_success(self, mock_document_service_dependencies): - """ - Test successful retry of multiple documents. - - Verifies that when multiple documents are retried, all retry - processes complete successfully. - - This test ensures: - - Multiple documents can be retried - - All documents are processed - - Retry task is triggered with all document IDs - """ - # Arrange - dataset_id = "dataset-123" - document1 = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", dataset_id=dataset_id, indexing_status="error" - ) - document2 = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-456", dataset_id=dataset_id, indexing_status="error" - ) - - # Mock Redis to return None (not retrying) - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - DocumentService.retry_document(dataset_id, [document1, document2]) - - # Assert - assert document1.indexing_status == "waiting" - assert document2.indexing_status == "waiting" - - # Verify retry task was triggered with all document IDs - mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( - dataset_id, [document1.id, document2.id], "user-123" - ) - - def test_retry_document_concurrent_retry_error(self, mock_document_service_dependencies): - """ - Test error when document is already being retried. - - Verifies that when a document is already being retried, a new - retry attempt raises a ValueError. 
- - This test ensures: - - Concurrent retries are prevented - - Error message is clear - - Error type is correct - """ - # Arrange - dataset_id = "dataset-123" - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", dataset_id=dataset_id, indexing_status="error" - ) - - # Mock Redis to return retry flag (already retrying) - mock_document_service_dependencies["redis_client"].get.return_value = "1" - - # Act & Assert - with pytest.raises(ValueError, match="Document is being retried, please try again later"): - DocumentService.retry_document(dataset_id, [document]) - - # Verify no database operations were performed - mock_document_service_dependencies["db_session"].add.assert_not_called() - mock_document_service_dependencies["db_session"].commit.assert_not_called() - - def test_retry_document_missing_current_user_error(self, mock_document_service_dependencies): - """ - Test error when current_user is missing. - - Verifies that when current_user is None or has no ID, a ValueError - is raised. 
- - This test ensures: - - Current user validation works correctly - - Error message is clear - - Error type is correct - """ - # Arrange - dataset_id = "dataset-123" - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", dataset_id=dataset_id, indexing_status="error" - ) - - # Mock Redis to return None (not retrying) - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Mock current_user to be None - mock_document_service_dependencies["current_user"].id = None - - # Act & Assert - with pytest.raises(ValueError, match="Current user or current user id not found"): - DocumentService.retry_document(dataset_id, [document]) - - -# ============================================================================ -# Tests for batch_update_document_status -# ============================================================================ - class TestDocumentServiceBatchUpdateDocumentStatus: - """ - Comprehensive unit tests for DocumentService.batch_update_document_status method. + """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" - This test class covers the batch document status update functionality, - which allows users to update the status of multiple documents at once. - - The batch_update_document_status method: - 1. Validates action parameter - 2. Validates all documents - 3. Checks if documents are being indexed - 4. Prepares updates for each document - 5. Applies all updates in a single transaction - 6. Triggers async tasks - 7. Sets Redis cache flags - - Test scenarios include: - - Batch enabling documents - - Batch disabling documents - - Batch archiving documents - - Batch unarchiving documents - - Handling empty lists - - Invalid action handling - - Document indexing check - - Transaction rollback on errors - """ - - @pytest.fixture - def mock_document_service_dependencies(self): - """ - Mock document service dependencies for testing. 
- - Provides mocked dependencies including: - - get_document method - - Database session - - Redis client - - Async tasks - """ - with ( - patch("services.dataset_service.DocumentService.get_document") as mock_get_document, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.add_document_to_index_task") as mock_add_task, - patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_document": mock_get_document, - "db_session": mock_db, - "redis_client": mock_redis, - "add_task": mock_add_task, - "remove_task": mock_remove_task, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_batch_update_document_status_enable_success(self, mock_document_service_dependencies): - """ - Test successful batch enabling of documents. - - Verifies that when documents are enabled in batch, all operations - complete successfully. 
- - This test ensures: - - Documents are retrieved correctly - - Enabled flag is set - - Async tasks are triggered - - Redis cache flags are set - - Transaction is committed - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock() - document_ids = ["document-123", "document-456"] - - document1 = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", enabled=False, indexing_status="completed" - ) - document2 = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-456", enabled=False, indexing_status="completed" - ) - - mock_document_service_dependencies["get_document"].side_effect = [document1, document2] - mock_document_service_dependencies["redis_client"].get.return_value = None # Not indexing - - # Act - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) - - # Assert - assert document1.enabled is True - assert document2.enabled is True - - # Verify database operations - mock_document_service_dependencies["db_session"].add.assert_called() - mock_document_service_dependencies["db_session"].commit.assert_called_once() - - # Verify async tasks were triggered - assert mock_document_service_dependencies["add_task"].delay.call_count == 2 - - def test_batch_update_document_status_disable_success(self, mock_document_service_dependencies): - """ - Test successful batch disabling of documents. - - Verifies that when documents are disabled in batch, all operations - complete successfully. 
- - This test ensures: - - Documents are retrieved correctly - - Enabled flag is cleared - - Disabled_at and disabled_by are set - - Async tasks are triggered - - Transaction is committed - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock(user_id="user-123") - document_ids = ["document-123"] - - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", - enabled=True, - indexing_status="completed", - completed_at=datetime.datetime.now(), - ) - - mock_document_service_dependencies["get_document"].return_value = document - mock_document_service_dependencies["redis_client"].get.return_value = None # Not indexing - - # Act - DocumentService.batch_update_document_status(dataset, document_ids, "disable", user) - - # Assert - assert document.enabled is False - assert document.disabled_at == mock_document_service_dependencies["current_time"] - assert document.disabled_by == "user-123" - - # Verify async task was triggered - mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) - - def test_batch_update_document_status_archive_success(self, mock_document_service_dependencies): - """ - Test successful batch archiving of documents. - - Verifies that when documents are archived in batch, all operations - complete successfully. 
- - This test ensures: - - Documents are retrieved correctly - - Archived flag is set - - Archived_at and archived_by are set - - Async tasks are triggered for enabled documents - - Transaction is committed - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock(user_id="user-123") - document_ids = ["document-123"] - - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", archived=False, enabled=True - ) - - mock_document_service_dependencies["get_document"].return_value = document - mock_document_service_dependencies["redis_client"].get.return_value = None # Not indexing - - # Act - DocumentService.batch_update_document_status(dataset, document_ids, "archive", user) - - # Assert - assert document.archived is True - assert document.archived_at == mock_document_service_dependencies["current_time"] - assert document.archived_by == "user-123" - - # Verify async task was triggered for enabled document - mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) - - def test_batch_update_document_status_unarchive_success(self, mock_document_service_dependencies): - """ - Test successful batch unarchiving of documents. - - Verifies that when documents are unarchived in batch, all operations - complete successfully. 
- - This test ensures: - - Documents are retrieved correctly - - Archived flag is cleared - - Archived_at and archived_by are cleared - - Async tasks are triggered for enabled documents - - Transaction is committed - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock() - document_ids = ["document-123"] - - document = DocumentStatusTestDataFactory.create_document_mock( - document_id="document-123", archived=True, enabled=True - ) - - mock_document_service_dependencies["get_document"].return_value = document - mock_document_service_dependencies["redis_client"].get.return_value = None # Not indexing - - # Act - DocumentService.batch_update_document_status(dataset, document_ids, "un_archive", user) - - # Assert - assert document.archived is False - assert document.archived_at is None - assert document.archived_by is None - - # Verify async task was triggered for enabled document - mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) - - def test_batch_update_document_status_empty_list(self, mock_document_service_dependencies): - """ - Test handling of empty document list. - - Verifies that when an empty list is provided, the method returns - early without performing any operations. 
- - This test ensures: - - Empty lists are handled gracefully - - No database operations are performed - - No errors are raised - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock() - document_ids = [] - - # Act - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) - - # Assert - # Verify no database operations were performed - mock_document_service_dependencies["db_session"].add.assert_not_called() - mock_document_service_dependencies["db_session"].commit.assert_not_called() - - def test_batch_update_document_status_invalid_action_error(self, mock_document_service_dependencies): + def test_batch_update_document_status_invalid_action_error(self): """ Test error handling for invalid action. @@ -1031,285 +68,3 @@ class TestDocumentServiceBatchUpdateDocumentStatus: # Act & Assert with pytest.raises(ValueError, match="Invalid action"): DocumentService.batch_update_document_status(dataset, document_ids, "invalid_action", user) - - def test_batch_update_document_status_document_indexing_error(self, mock_document_service_dependencies): - """ - Test error when document is being indexed. - - Verifies that when a document is currently being indexed, a - DocumentIndexingError is raised. 
- - This test ensures: - - Indexing documents cannot be updated - - Error message is clear - - Error type is correct - """ - # Arrange - dataset = DocumentStatusTestDataFactory.create_dataset_mock() - user = DocumentStatusTestDataFactory.create_user_mock() - document_ids = ["document-123"] - - document = DocumentStatusTestDataFactory.create_document_mock(document_id="document-123") - - mock_document_service_dependencies["get_document"].return_value = document - mock_document_service_dependencies["redis_client"].get.return_value = "1" # Currently indexing - - # Act & Assert - with pytest.raises(DocumentIndexingError, match="is being indexed"): - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) - - -# ============================================================================ -# Tests for rename_document -# ============================================================================ - - -class TestDocumentServiceRenameDocument: - """ - Comprehensive unit tests for DocumentService.rename_document method. - - This test class covers the document renaming functionality, which allows - users to rename documents for better organization. - - The rename_document method: - 1. Validates dataset exists - 2. Validates document exists - 3. Validates tenant permission - 4. Updates document name - 5. Updates metadata if built-in fields enabled - 6. Updates associated upload file name - 7. Commits changes - - Test scenarios include: - - Successful document renaming - - Dataset not found error - - Document not found error - - Permission validation - - Metadata updates - - Upload file name updates - """ - - @pytest.fixture - def mock_document_service_dependencies(self): - """ - Mock document service dependencies for testing. 
- - Provides mocked dependencies including: - - DatasetService.get_dataset - - DocumentService.get_document - - current_user context - - Database session - """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DocumentService.get_document") as mock_get_document, - patch( - "services.dataset_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, - patch("extensions.ext_database.db.session") as mock_db, - ): - mock_current_user.current_tenant_id = "tenant-123" - - yield { - "get_dataset": mock_get_dataset, - "get_document": mock_get_document, - "current_user": mock_current_user, - "db_session": mock_db, - } - - def test_rename_document_success(self, mock_document_service_dependencies): - """ - Test successful document renaming. - - Verifies that when all validation passes, a document is renamed - successfully. - - This test ensures: - - Dataset is retrieved correctly - - Document is retrieved correctly - - Document name is updated - - Changes are committed - """ - # Arrange - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - - dataset = DocumentStatusTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - document = DocumentStatusTestDataFactory.create_document_mock( - document_id=document_id, dataset_id=dataset_id, tenant_id="tenant-123" - ) - - mock_document_service_dependencies["get_dataset"].return_value = dataset - mock_document_service_dependencies["get_document"].return_value = document - - # Act - result = DocumentService.rename_document(dataset_id, document_id, new_name) - - # Assert - assert result == document - assert document.name == new_name - - # Verify database operations - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - - def test_rename_document_with_built_in_fields(self, 
mock_document_service_dependencies): - """ - Test document renaming with built-in fields enabled. - - Verifies that when built-in fields are enabled, the document - metadata is also updated. - - This test ensures: - - Document name is updated - - Metadata is updated with new name - - Built-in field is set correctly - """ - # Arrange - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - - dataset = DocumentStatusTestDataFactory.create_dataset_mock(dataset_id=dataset_id, built_in_field_enabled=True) - document = DocumentStatusTestDataFactory.create_document_mock( - document_id=document_id, - dataset_id=dataset_id, - tenant_id="tenant-123", - doc_metadata={"existing_key": "existing_value"}, - ) - - mock_document_service_dependencies["get_dataset"].return_value = dataset - mock_document_service_dependencies["get_document"].return_value = document - - # Act - DocumentService.rename_document(dataset_id, document_id, new_name) - - # Assert - assert document.name == new_name - assert "document_name" in document.doc_metadata - assert document.doc_metadata["document_name"] == new_name - assert document.doc_metadata["existing_key"] == "existing_value" # Existing metadata preserved - - def test_rename_document_with_upload_file(self, mock_document_service_dependencies): - """ - Test document renaming with associated upload file. - - Verifies that when a document has an associated upload file, - the file name is also updated. 
- - This test ensures: - - Document name is updated - - Upload file name is updated - - Database query is executed correctly - """ - # Arrange - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - file_id = "file-123" - - dataset = DocumentStatusTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - document = DocumentStatusTestDataFactory.create_document_mock( - document_id=document_id, - dataset_id=dataset_id, - tenant_id="tenant-123", - data_source_info={"upload_file_id": file_id}, - ) - - mock_document_service_dependencies["get_dataset"].return_value = dataset - mock_document_service_dependencies["get_document"].return_value = document - - # Mock upload file query - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.update.return_value = None - mock_document_service_dependencies["db_session"].query.return_value = mock_query - - # Act - DocumentService.rename_document(dataset_id, document_id, new_name) - - # Assert - assert document.name == new_name - - # Verify upload file query was executed - mock_document_service_dependencies["db_session"].query.assert_called() - - def test_rename_document_dataset_not_found_error(self, mock_document_service_dependencies): - """ - Test error when dataset is not found. - - Verifies that when the dataset ID doesn't exist, a ValueError - is raised. - - This test ensures: - - Dataset existence is validated - - Error message is clear - - Error type is correct - """ - # Arrange - dataset_id = "non-existent-dataset" - document_id = "document-123" - new_name = "New Document Name" - - mock_document_service_dependencies["get_dataset"].return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document(dataset_id, document_id, new_name) - - def test_rename_document_not_found_error(self, mock_document_service_dependencies): - """ - Test error when document is not found. 
- - Verifies that when the document ID doesn't exist, a ValueError - is raised. - - This test ensures: - - Document existence is validated - - Error message is clear - - Error type is correct - """ - # Arrange - dataset_id = "dataset-123" - document_id = "non-existent-document" - new_name = "New Document Name" - - dataset = DocumentStatusTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - mock_document_service_dependencies["get_dataset"].return_value = dataset - mock_document_service_dependencies["get_document"].return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset_id, document_id, new_name) - - def test_rename_document_permission_error(self, mock_document_service_dependencies): - """ - Test error when user lacks permission. - - Verifies that when the user is in a different tenant, a ValueError - is raised. - - This test ensures: - - Tenant permission is validated - - Error message is clear - - Error type is correct - """ - # Arrange - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - - dataset = DocumentStatusTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - document = DocumentStatusTestDataFactory.create_document_mock( - document_id=document_id, - dataset_id=dataset_id, - tenant_id="tenant-456", # Different tenant - ) - - mock_document_service_dependencies["get_dataset"].return_value = dataset - mock_document_service_dependencies["get_document"].return_value = document - - # Act & Assert - with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset_id, document_id, new_name) diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py new file mode 100644 index 0000000000..03c4f793cf --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -0,0 +1,141 @@ 
+"""Unit tests for enterprise service integrations. + +This module covers the enterprise-only default workspace auto-join behavior: +- Enterprise mode disabled: no external calls +- Successful join / skipped join: no errors +- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise +""" + +from unittest.mock import patch + +import pytest + +from services.enterprise.enterprise_service import ( + DefaultWorkspaceJoinResult, + EnterpriseService, + try_join_default_workspace, +) + + +class TestJoinDefaultWorkspace: + def test_join_default_workspace_success(self): + account_id = "11111111-1111-1111-1111-111111111111" + response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"} + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = response + + result = EnterpriseService.join_default_workspace(account_id=account_id) + + assert isinstance(result, DefaultWorkspaceJoinResult) + assert result.workspace_id == response["workspace_id"] + assert result.joined is True + assert result.message == "ok" + + mock_send_request.assert_called_once_with( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + timeout=1.0, + raise_for_status=True, + ) + + def test_join_default_workspace_invalid_response_format_raises(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = "not-a-dict" + + with pytest.raises(ValueError, match="Invalid response format"): + EnterpriseService.join_default_workspace(account_id=account_id) + + def test_join_default_workspace_invalid_account_id_raises(self): + with pytest.raises(ValueError): + EnterpriseService.join_default_workspace(account_id="not-a-uuid") + + def test_join_default_workspace_missing_required_fields_raises(self): + 
account_id = "11111111-1111-1111-1111-111111111111" + response = {"workspace_id": "", "message": "ok"} # missing "joined" + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = response + + with pytest.raises(ValueError, match="Invalid response payload"): + EnterpriseService.join_default_workspace(account_id=account_id) + + def test_join_default_workspace_joined_without_workspace_id_raises(self): + with pytest.raises(ValueError, match="workspace_id must be non-empty when joined is True"): + DefaultWorkspaceJoinResult(workspace_id="", joined=True, message="ok") + + +class TestTryJoinDefaultWorkspace: + def test_try_join_default_workspace_enterprise_disabled_noop(self): + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = False + + try_join_default_workspace("11111111-1111-1111-1111-111111111111") + + mock_join.assert_not_called() + + def test_try_join_default_workspace_successful_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="22222222-2222-2222-2222-222222222222", + joined=True, + message="ok", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_skipped_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + 
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="", + joined=False, + message="no default workspace configured", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_api_failure_soft_fails(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.side_effect = Exception("network failure") + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_invalid_account_id_soft_fails(self): + with patch("services.enterprise.enterprise_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Should not raise even though UUID parsing fails inside join_default_workspace + try_join_default_workspace("not-a-uuid") diff --git a/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py index 87c03f13a3..a98a9e97e2 100644 --- a/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py +++ b/api/tests/unit_tests/services/enterprise/test_traceparent_propagation.py @@ -27,7 +27,7 @@ class TestTraceparentPropagation: @pytest.fixture def mock_httpx_client(self): """Mock httpx.Client for testing.""" - with patch("services.enterprise.base.httpx.Client") as mock_client_class: + with patch("services.enterprise.base.httpx.Client", autospec=True) as mock_client_class: mock_client = MagicMock() 
mock_client_class.return_value.__enter__.return_value = mock_client mock_client_class.return_value.__exit__.return_value = None @@ -44,7 +44,9 @@ class TestTraceparentPropagation: # Arrange expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01" - with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent): + with patch( + "services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent, autospec=True + ): # Act EnterpriseRequest.send_request("GET", "/test") diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index 1647eb3e85..57364142ad 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -135,8 +135,8 @@ class TestExternalDatasetServiceGetExternalKnowledgeApis: """ with ( - patch("services.external_knowledge_service.db.paginate") as mock_paginate, - patch("services.external_knowledge_service.select"), + patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate, + patch("services.external_knowledge_service.select", autospec=True), ): yield mock_paginate @@ -245,7 +245,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi: Patch ``db.session`` for all CRUD tests in this class. """ - with patch("services.external_knowledge_service.db.session") as mock_session: + with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: yield mock_session def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock): @@ -263,7 +263,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi: } # We do not want to actually call the remote endpoint here, so we patch the validator. 
- with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check: + with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check: result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args) assert isinstance(result, ExternalKnowledgeApis) @@ -386,7 +386,7 @@ class TestExternalDatasetServiceUsageAndBindings: @pytest.fixture def mock_db_session(self): - with patch("services.external_knowledge_service.db.session") as mock_session: + with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: yield mock_session def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock): @@ -447,7 +447,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate: @pytest.fixture def mock_db_session(self): - with patch("services.external_knowledge_service.db.session") as mock_session: + with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: yield mock_session def test_document_create_args_validate_success(self, mock_db_session: MagicMock): @@ -520,7 +520,7 @@ class TestExternalDatasetServiceProcessExternalApi: fake_response = httpx.Response(200) - with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post: + with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post: mock_post.return_value = fake_response result = ExternalDatasetService.process_external_api(settings, files=None) @@ -681,7 +681,7 @@ class TestExternalDatasetServiceCreateExternalDataset: @pytest.fixture def mock_db_session(self): - with patch("services.external_knowledge_service.db.session") as mock_session: + with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: yield mock_session def test_create_external_dataset_success(self, mock_db_session: MagicMock): @@ -801,7 +801,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval: 
@pytest.fixture def mock_db_session(self): - with patch("services.external_knowledge_service.db.session") as mock_session: + with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session: yield mock_session def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock): @@ -838,7 +838,9 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval: metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"}) - with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process: + with patch.object( + ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True + ) as mock_process: result = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=tenant_id, dataset_id=dataset_id, @@ -908,7 +910,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval: fake_response.status_code = 500 fake_response.json.return_value = {} - with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response): + with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True): result = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id="tenant-1", dataset_id="ds-1", diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index 17f3a7e94e..22ab8503df 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -146,7 +146,7 @@ class TestHitTestingServiceRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. 
""" - with patch("services.hit_testing_service.db.session") as mock_db: + with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: yield mock_db def test_retrieve_success_with_default_retrieval_model(self, mock_db_session): @@ -174,9 +174,11 @@ class TestHitTestingServiceRetrieve: ] with ( - patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve, - patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve, + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] # start, end mock_retrieve.return_value = documents @@ -218,9 +220,11 @@ class TestHitTestingServiceRetrieve: mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()] with ( - patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve, - patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve, + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] mock_retrieve.return_value = documents @@ -268,10 +272,12 @@ class TestHitTestingServiceRetrieve: mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()] with ( - 
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve, - patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format, - patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve, + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] mock_dataset_retrieval_class.return_value = mock_dataset_retrieval @@ -311,8 +317,10 @@ class TestHitTestingServiceRetrieve: mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True) with ( - patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class, - patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format, + patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class, + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, ): mock_dataset_retrieval_class.return_value = mock_dataset_retrieval mock_format.return_value = [] @@ -346,9 +354,11 @@ class TestHitTestingServiceRetrieve: mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()] with ( - patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve, - patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + 
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve, + patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] mock_retrieve.return_value = documents @@ -380,7 +390,7 @@ class TestHitTestingServiceExternalRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. """ - with patch("services.hit_testing_service.db.session") as mock_db: + with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: yield mock_db def test_external_retrieve_success(self, mock_db_session): @@ -403,8 +413,10 @@ class TestHitTestingServiceExternalRetrieve: ] with ( - patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch( + "services.hit_testing_service.RetrievalService.external_retrieve", autospec=True + ) as mock_external_retrieve, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] mock_external_retrieve.return_value = external_documents @@ -467,8 +479,10 @@ class TestHitTestingServiceExternalRetrieve: external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}] with ( - patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch( + "services.hit_testing_service.RetrievalService.external_retrieve", autospec=True + ) as mock_external_retrieve, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] 
mock_external_retrieve.return_value = external_documents @@ -499,8 +513,10 @@ class TestHitTestingServiceExternalRetrieve: metadata_filtering_conditions = {} with ( - patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve, - patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter, + patch( + "services.hit_testing_service.RetrievalService.external_retrieve", autospec=True + ) as mock_external_retrieve, + patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter, ): mock_perf_counter.side_effect = [0.0, 0.1] mock_external_retrieve.return_value = [] @@ -542,7 +558,9 @@ class TestHitTestingServiceCompactRetrieveResponse: HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85), ] - with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format: + with patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format: mock_format.return_value = mock_records # Act @@ -566,7 +584,9 @@ class TestHitTestingServiceCompactRetrieveResponse: query = "test query" documents = [] - with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format: + with patch( + "services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True + ) as mock_format: mock_format.return_value = [] # Act diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index ee05e890b2..affbc8d0b5 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -147,7 +147,7 @@ class TestSegmentServiceCreateSegment: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: 
yield mock_db @pytest.fixture @@ -172,10 +172,12 @@ class TestSegmentServiceCreateSegment: mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_segments_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -219,10 +221,12 @@ class TestSegmentServiceCreateSegment: mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_segments_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -257,11 +261,13 @@ class TestSegmentServiceCreateSegment: 
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, - patch("services.dataset_service.ModelManager") as mock_model_manager_class, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_segments_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -292,10 +298,12 @@ class TestSegmentServiceCreateSegment: mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_segments_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -317,7 +325,7 @@ 
class TestSegmentServiceUpdateSegment: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -338,10 +346,10 @@ class TestSegmentServiceUpdateSegment: mock_db_session.query.return_value.where.return_value.first.return_value = segment with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_redis_get.return_value = None # Not indexing mock_hash.return_value = "new-hash" @@ -368,10 +376,10 @@ class TestSegmentServiceUpdateSegment: args = SegmentUpdateArgs(enabled=False) with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.redis_client.setex") as mock_redis_setex, - patch("services.dataset_service.disable_segment_from_index_task") as mock_task, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, + patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_redis_get.return_value = None 
mock_now.return_value = "2024-01-01T00:00:00" @@ -394,7 +402,7 @@ class TestSegmentServiceUpdateSegment: dataset = SegmentTestDataFactory.create_dataset_mock() args = SegmentUpdateArgs(content="Updated content") - with patch("services.dataset_service.redis_client.get") as mock_redis_get: + with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: mock_redis_get.return_value = "1" # Indexing in progress # Act & Assert @@ -409,7 +417,7 @@ class TestSegmentServiceUpdateSegment: dataset = SegmentTestDataFactory.create_dataset_mock() args = SegmentUpdateArgs(content="Updated content") - with patch("services.dataset_service.redis_client.get") as mock_redis_get: + with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: mock_redis_get.return_value = None # Act & Assert @@ -427,10 +435,10 @@ class TestSegmentServiceUpdateSegment: mock_db_session.query.return_value.where.return_value.first.return_value = segment with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_redis_get.return_value = None mock_hash.return_value = "new-hash" @@ -456,7 +464,7 @@ class TestSegmentServiceDeleteSegment: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) 
as mock_db: yield mock_db def test_delete_segment_success(self, mock_db_session): @@ -471,10 +479,10 @@ class TestSegmentServiceDeleteSegment: mock_db_session.scalars.return_value = mock_scalars with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.redis_client.setex") as mock_redis_setex, - patch("services.dataset_service.delete_segment_from_index_task") as mock_task, - patch("services.dataset_service.select") as mock_select, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex, + patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, + patch("services.dataset_service.select", autospec=True) as mock_select, ): mock_redis_get.return_value = None mock_select.return_value.where.return_value = mock_select @@ -495,8 +503,8 @@ class TestSegmentServiceDeleteSegment: dataset = SegmentTestDataFactory.create_dataset_mock() with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.delete_segment_from_index_task") as mock_task, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, ): mock_redis_get.return_value = None @@ -515,7 +523,7 @@ class TestSegmentServiceDeleteSegment: document = SegmentTestDataFactory.create_document_mock() dataset = SegmentTestDataFactory.create_dataset_mock() - with patch("services.dataset_service.redis_client.get") as mock_redis_get: + with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get: mock_redis_get.return_value = "1" # Deletion in progress # Act & Assert @@ -529,7 +537,7 @@ class TestSegmentServiceDeleteSegments: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with 
patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -562,8 +570,8 @@ class TestSegmentServiceDeleteSegments: mock_db_session.scalars.return_value = mock_scalars with ( - patch("services.dataset_service.delete_segment_from_index_task") as mock_task, - patch("services.dataset_service.select") as mock_select_func, + patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task, + patch("services.dataset_service.select", autospec=True) as mock_select_func, ): mock_select_func.return_value = mock_select @@ -594,7 +602,7 @@ class TestSegmentServiceUpdateSegmentsStatus: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -623,9 +631,9 @@ class TestSegmentServiceUpdateSegmentsStatus: mock_db_session.scalars.return_value = mock_scalars with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.enable_segments_to_index_task") as mock_task, - patch("services.dataset_service.select") as mock_select_func, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task, + patch("services.dataset_service.select", autospec=True) as mock_select_func, ): mock_redis_get.return_value = None mock_select_func.return_value = mock_select @@ -657,10 +665,10 @@ class TestSegmentServiceUpdateSegmentsStatus: mock_db_session.scalars.return_value = mock_scalars with ( - patch("services.dataset_service.redis_client.get") as mock_redis_get, - patch("services.dataset_service.disable_segments_from_index_task") as mock_task, - patch("services.dataset_service.naive_utc_now") as mock_now, - 
patch("services.dataset_service.select") as mock_select_func, + patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get, + patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, + patch("services.dataset_service.select", autospec=True) as mock_select_func, ): mock_redis_get.return_value = None mock_now.return_value = "2024-01-01T00:00:00" @@ -693,7 +701,7 @@ class TestSegmentServiceGetSegments: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -771,7 +779,7 @@ class TestSegmentServiceGetSegmentById: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db def test_get_segment_by_id_success(self, mock_db_session): @@ -814,7 +822,7 @@ class TestSegmentServiceGetChildChunks: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -876,7 +884,7 @@ class TestSegmentServiceGetChildChunkById: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db def test_get_child_chunk_by_id_success(self, mock_db_session): @@ -919,7 +927,7 @@ class TestSegmentServiceCreateChildChunk: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + 
with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -942,9 +950,11 @@ class TestSegmentServiceCreateChildChunk: mock_db_session.query.return_value = mock_query with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -972,9 +982,11 @@ class TestSegmentServiceCreateChildChunk: mock_db_session.query.return_value = mock_query with ( - patch("services.dataset_service.redis_client.lock") as mock_lock, - patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service, - patch("services.dataset_service.helper.generate_text_hash") as mock_hash, + patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock, + patch( + "services.dataset_service.VectorService.create_child_chunk_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, ): mock_lock.return_value.__enter__ = Mock() mock_lock.return_value.__exit__ = Mock(return_value=None) @@ -994,7 +1006,7 @@ class TestSegmentServiceUpdateChildChunk: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db @pytest.fixture @@ -1014,8 +1026,10 @@ class TestSegmentServiceUpdateChildChunk: 
dataset = SegmentTestDataFactory.create_dataset_mock() with ( - patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch( + "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_now.return_value = "2024-01-01T00:00:00" @@ -1040,8 +1054,10 @@ class TestSegmentServiceUpdateChildChunk: dataset = SegmentTestDataFactory.create_dataset_mock() with ( - patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service, - patch("services.dataset_service.naive_utc_now") as mock_now, + patch( + "services.dataset_service.VectorService.update_child_chunk_vector", autospec=True + ) as mock_vector_service, + patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): mock_vector_service.side_effect = Exception("Vector indexing failed") mock_now.return_value = "2024-01-01T00:00:00" @@ -1059,7 +1075,7 @@ class TestSegmentServiceDeleteChildChunk: @pytest.fixture def mock_db_session(self): """Mock database session.""" - with patch("services.dataset_service.db.session") as mock_db: + with patch("services.dataset_service.db.session", autospec=True) as mock_db: yield mock_db def test_delete_child_chunk_success(self, mock_db_session): @@ -1068,7 +1084,9 @@ class TestSegmentServiceDeleteChildChunk: chunk = SegmentTestDataFactory.create_child_chunk_mock() dataset = SegmentTestDataFactory.create_dataset_mock() - with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service: + with patch( + "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True + ) as mock_vector_service: # Act SegmentService.delete_child_chunk(chunk, dataset) @@ -1083,7 +1101,9 @@ class TestSegmentServiceDeleteChildChunk: chunk = 
SegmentTestDataFactory.create_child_chunk_mock() dataset = SegmentTestDataFactory.create_dataset_mock() - with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service: + with patch( + "services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True + ) as mock_vector_service: mock_vector_service.side_effect = Exception("Vector deletion failed") # Act & Assert diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 1fc45d1c35..635c86a14b 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1064,6 +1064,67 @@ class TestRegisterService: # ==================== Registration Tests ==================== + def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + result = AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", 
+ password=None, + ) + + assert result == mock_account + mock_create_workspace.assert_called_once_with(account=mock_account) + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + + def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + mock_create_workspace.assert_called_once_with(account=mock_account) + mock_join_default_workspace.assert_not_called() + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): """Test successful account registration.""" # Setup mocks @@ -1115,6 +1176,65 @@ class TestRegisterService: mock_event.send.assert_called_once_with(mock_tenant) self._assert_database_operations_called(mock_db_dependencies["db"]) + def test_register_calls_default_workspace_join_when_enterprise_enabled( + self, mock_db_dependencies, 
mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should be invoked after successful register commit.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + + result = RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + create_workspace_required=False, + ) + + assert result == mock_account + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + + def test_register_does_not_call_default_workspace_join_when_enterprise_disabled( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False.""" + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False) + + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + 
): + mock_create_account.return_value = mock_account + + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + create_workspace_required=False, + ) + + mock_join_default_workspace.assert_not_called() + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): """Test account registration with OAuth integration.""" # Setup mocks diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py index ef62dacd6b..eadcf48b2e 100644 --- a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -15,8 +15,8 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME class TestWorkflowRunArchiver: """Tests for the WorkflowRunArchiver class.""" - @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") - @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True) + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True) def test_archiver_initialization(self, mock_get_storage, mock_config): """Test archiver can be initialized with various options.""" from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 2467e01993..5d67469105 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -214,7 +214,7 @@ def factory(): class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" - @patch("services.audio_service.ModelManager") + 
@patch("services.audio_service.ModelManager", autospec=True) def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in CHAT mode.""" # Arrange @@ -226,9 +226,7 @@ class TestAudioServiceASR: file = factory.create_file_storage_mock() # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_speech2text.return_value = "Transcribed text" mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -242,7 +240,7 @@ class TestAudioServiceASR: call_args = mock_model_instance.invoke_speech2text.call_args assert call_args.kwargs["user"] == "user-123" - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange @@ -254,9 +252,7 @@ class TestAudioServiceASR: file = factory.create_file_storage_mock() # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text" mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -351,7 +347,7 @@ class TestAudioServiceASR: with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"): AudioService.transcript_asr(app_model=app, file=file) - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that ASR raises error when no 
model instance is available.""" # Arrange @@ -363,8 +359,7 @@ class TestAudioServiceASR: file = factory.create_file_storage_mock() # Mock ModelManager to return None - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager + mock_model_manager = mock_model_manager_class.return_value mock_model_manager.get_default_model_instance.return_value = None # Act & Assert @@ -375,7 +370,7 @@ class TestAudioServiceASR: class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): """Test successful TTS with text input.""" # Arrange @@ -388,9 +383,7 @@ class TestAudioServiceTTS: ) # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"audio data" mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -412,8 +405,8 @@ class TestAudioServiceTTS: voice="en-US-Neural", ) - @patch("services.audio_service.db.session") - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" # Arrange @@ -437,9 +430,7 @@ class TestAudioServiceTTS: mock_query.first.return_value = message # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"audio from message" 
mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -454,7 +445,7 @@ class TestAudioServiceTTS: assert result == b"audio from message" mock_model_instance.invoke_tts.assert_called_once() - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): """Test TTS uses default voice when none specified.""" # Arrange @@ -467,9 +458,7 @@ class TestAudioServiceTTS: ) # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"audio data" mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -486,7 +475,7 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "default-voice" - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): """Test TTS gets first available voice when none is configured.""" # Arrange @@ -499,9 +488,7 @@ class TestAudioServiceTTS: ) # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}] mock_model_instance.invoke_tts.return_value = b"audio data" @@ -518,8 +505,8 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "auto-voice" - @patch("services.audio_service.WorkflowService") - @patch("services.audio_service.ModelManager") + 
@patch("services.audio_service.WorkflowService", autospec=True) + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_workflow_mode_with_draft( self, mock_model_manager_class, mock_workflow_service_class, factory ): @@ -533,14 +520,11 @@ class TestAudioServiceTTS: ) # Mock WorkflowService - mock_workflow_service = MagicMock() - mock_workflow_service_class.return_value = mock_workflow_service + mock_workflow_service = mock_workflow_service_class.return_value mock_workflow_service.get_draft_workflow.return_value = draft_workflow # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"draft audio" mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -565,7 +549,7 @@ class TestAudioServiceTTS: with pytest.raises(ValueError, match="Text is required"): AudioService.transcript_tts(app_model=app, text=None) - @patch("services.audio_service.db.session") + @patch("services.audio_service.db.session", autospec=True) def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory): """Test that TTS returns None for invalid message ID format.""" # Arrange @@ -580,7 +564,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session") + @patch("services.audio_service.db.session", autospec=True) def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory): """Test that TTS returns None when message doesn't exist.""" # Arrange @@ -601,7 +585,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.db.session") + @patch("services.audio_service.db.session", autospec=True) def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory): """Test that 
TTS returns None when message answer is empty.""" # Arrange @@ -627,7 +611,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): """Test that TTS raises error when no voices are available.""" # Arrange @@ -640,9 +624,7 @@ class TestAudioServiceTTS: ) # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.get_tts_voices.return_value = [] # No voices available mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -655,7 +637,7 @@ class TestAudioServiceTTS: class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): """Test successful retrieval of TTS voices.""" # Arrange @@ -668,9 +650,7 @@ class TestAudioServiceTTSVoices: ] # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.get_tts_voices.return_value = expected_voices mock_model_manager.get_default_model_instance.return_value = mock_model_instance @@ -682,7 +662,7 @@ class TestAudioServiceTTSVoices: assert result == expected_voices mock_model_instance.get_tts_voices.assert_called_once_with(language) - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, 
factory): """Test that TTS voices raises error when no model instance is available.""" # Arrange @@ -690,15 +670,14 @@ class TestAudioServiceTTSVoices: language = "en-US" # Mock ModelManager to return None - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager + mock_model_manager = mock_model_manager_class.return_value mock_model_manager.get_default_model_instance.return_value = None # Act & Assert with pytest.raises(ProviderNotSupportTextToSpeechServiceError): AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) - @patch("services.audio_service.ModelManager") + @patch("services.audio_service.ModelManager", autospec=True) def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): """Test that TTS voices propagates exceptions from model instance.""" # Arrange @@ -706,9 +685,7 @@ class TestAudioServiceTTSVoices: language = "en-US" # Mock ModelManager - mock_model_manager = MagicMock() - mock_model_manager_class.return_value = mock_model_manager - + mock_model_manager = mock_model_manager_class.return_value mock_model_instance = MagicMock() mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error") mock_model_manager.get_default_model_instance.return_value = mock_model_instance diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 0661c15623..d8ecdf45fd 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -237,9 +237,9 @@ class TestConversationServiceSummarization: titles based on the first message. 
""" - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.ConversationService.get_conversation") - @patch("services.conversation_service.ConversationService.auto_generate_name") + @patch("services.conversation_service.db.session", autospec=True) + @patch("services.conversation_service.ConversationService.get_conversation", autospec=True) + @patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True) def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): """ Test renaming conversation with auto-generation enabled. diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py index 80cce81e89..a1d2f6410c 100644 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ b/api/tests/unit_tests/services/test_dataset_service.py @@ -1,922 +1,45 @@ -""" -Comprehensive unit tests for DatasetService. +"""Unit tests for non-SQL DocumentService orchestration behaviors. -This test suite provides complete coverage of dataset management operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -## Test Coverage - -### 1. Dataset Creation (TestDatasetServiceCreateDataset) -Tests the creation of knowledge base datasets with various configurations: -- Internal datasets (provider='vendor') with economy or high-quality indexing -- External datasets (provider='external') connected to third-party APIs -- Embedding model configuration for semantic search -- Duplicate name validation -- Permission and access control setup - -### 2. Dataset Updates (TestDatasetServiceUpdateDataset) -Tests modification of existing dataset settings: -- Basic field updates (name, description, permission) -- Indexing technique switching (economy ↔ high_quality) -- Embedding model changes with vector index rebuilding -- Retrieval configuration updates -- External knowledge binding updates - -### 3. 
Dataset Deletion (TestDatasetServiceDeleteDataset) -Tests safe deletion with cascade cleanup: -- Normal deletion with documents and embeddings -- Empty dataset deletion (regression test for #27073) -- Permission verification -- Event-driven cleanup (vector DB, file storage) - -### 4. Document Indexing (TestDatasetServiceDocumentIndexing) -Tests async document processing operations: -- Pause/resume indexing for resource management -- Retry failed documents -- Status transitions through indexing pipeline -- Redis-based concurrency control - -### 5. Retrieval Configuration (TestDatasetServiceRetrievalConfiguration) -Tests search and ranking settings: -- Search method configuration (semantic, full-text, hybrid) -- Top-k and score threshold tuning -- Reranking model integration for improved relevance - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, Redis, model providers) - are mocked to ensure fast, isolated unit tests -- **Factory Pattern**: DatasetServiceTestDataFactory provides consistent test data -- **Fixtures**: Pytest fixtures set up common mock configurations per test class -- **Assertions**: Each test verifies both the return value and all side effects - (database operations, event signals, async task triggers) - -## Key Concepts - -**Indexing Techniques:** -- economy: Keyword-based search (fast, less accurate) -- high_quality: Vector embeddings for semantic search (slower, more accurate) - -**Dataset Providers:** -- vendor: Internal storage and indexing -- external: Third-party knowledge sources via API - -**Document Lifecycle:** -waiting β†’ parsing β†’ cleaning β†’ splitting β†’ indexing β†’ completed (or error) +This file intentionally keeps only collaborator-oriented document indexing +orchestration tests. SQL-backed dataset lifecycle cases are covered by +integration tests under testcontainers. 
""" -from unittest.mock import Mock, create_autospec, patch -from uuid import uuid4 +from unittest.mock import Mock, patch import pytest -from core.model_runtime.entities.model_entities import ModelType -from models.account import Account, TenantAccountRole -from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings -from services.dataset_service import DatasetService -from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.errors.dataset import DatasetNameDuplicateError +from models.dataset import Document +from services.errors.document import DocumentIndexingError -class DatasetServiceTestDataFactory: - """ - Factory class for creating test data and mock objects. - - This factory provides reusable methods to create mock objects for testing. - Using a factory pattern ensures consistency across tests and reduces code duplication. - All methods return properly configured Mock objects that simulate real model instances. - """ - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - tenant_id: str = "tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """ - Create a mock account with specified attributes. - - Args: - account_id: Unique identifier for the account - tenant_id: Tenant ID the account belongs to - role: User role (NORMAL, ADMIN, etc.) 
- **kwargs: Additional attributes to set on the mock - - Returns: - Mock: A properly configured Account mock object - """ - account = create_autospec(Account, instance=True) - account.id = account_id - account.current_tenant_id = tenant_id - account.current_role = role - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - name: str = "Test Dataset", - tenant_id: str = "tenant-123", - created_by: str = "user-123", - provider: str = "vendor", - indexing_technique: str | None = "high_quality", - **kwargs, - ) -> Mock: - """ - Create a mock dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - name: Display name of the dataset - tenant_id: Tenant ID the dataset belongs to - created_by: User ID who created the dataset - provider: Dataset provider type ('vendor' for internal, 'external' for external) - indexing_technique: Indexing method ('high_quality', 'economy', or None) - **kwargs: Additional attributes (embedding_model, retrieval_model, etc.) 
- - Returns: - Mock: A properly configured Dataset mock object - """ - dataset = create_autospec(Dataset, instance=True) - dataset.id = dataset_id - dataset.name = name - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.provider = provider - dataset.indexing_technique = indexing_technique - dataset.permission = kwargs.get("permission", DatasetPermissionEnum.ONLY_ME) - dataset.embedding_model_provider = kwargs.get("embedding_model_provider") - dataset.embedding_model = kwargs.get("embedding_model") - dataset.collection_binding_id = kwargs.get("collection_binding_id") - dataset.retrieval_model = kwargs.get("retrieval_model") - dataset.description = kwargs.get("description") - dataset.doc_form = kwargs.get("doc_form") - for key, value in kwargs.items(): - if not hasattr(dataset, key): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: - """ - Create a mock embedding model for high-quality indexing. - - Embedding models are used to convert text into vector representations - for semantic search capabilities. - - Args: - model: Model name (e.g., 'text-embedding-ada-002') - provider: Model provider (e.g., 'openai', 'cohere') - - Returns: - Mock: Embedding model mock with model and provider attributes - """ - embedding_model = Mock() - embedding_model.model_name = model - embedding_model.provider = provider - return embedding_model - - @staticmethod - def create_retrieval_model_mock() -> Mock: - """ - Create a mock retrieval model configuration. - - Retrieval models define how documents are searched and ranked, - including search method, top-k results, and score thresholds. 
- - Returns: - Mock: RetrievalModel mock with model_dump() method - """ - retrieval_model = Mock(spec=RetrievalModel) - retrieval_model.model_dump.return_value = { - "search_method": "semantic_search", - "top_k": 2, - "score_threshold": 0.0, - } - retrieval_model.reranking_model = None - return retrieval_model - - @staticmethod - def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock: - """ - Create a mock collection binding for vector database. - - Collection bindings link datasets to their vector storage locations - in the vector database (e.g., Qdrant, Weaviate). - - Args: - binding_id: Unique identifier for the collection binding - - Returns: - Mock: Collection binding mock object - """ - binding = Mock() - binding.id = binding_id - return binding - - @staticmethod - def create_external_binding_mock( - dataset_id: str = "dataset-123", - external_knowledge_id: str = "knowledge-123", - external_knowledge_api_id: str = "api-123", - ) -> Mock: - """ - Create a mock external knowledge binding. - - External knowledge bindings connect datasets to external knowledge sources - (e.g., third-party APIs, external databases) for retrieval. 
- - Args: - dataset_id: Dataset ID this binding belongs to - external_knowledge_id: External knowledge source identifier - external_knowledge_api_id: External API configuration identifier - - Returns: - Mock: ExternalKnowledgeBindings mock object - """ - binding = Mock(spec=ExternalKnowledgeBindings) - binding.dataset_id = dataset_id - binding.external_knowledge_id = external_knowledge_id - binding.external_knowledge_api_id = external_knowledge_api_id - return binding +class DatasetServiceUnitDataFactory: + """Factory for creating lightweight document doubles used in unit tests.""" @staticmethod def create_document_mock( document_id: str = "doc-123", dataset_id: str = "dataset-123", indexing_status: str = "completed", - **kwargs, + is_paused: bool = False, ) -> Mock: - """ - Create a mock document for testing document operations. - - Documents are the individual files/content items within a dataset - that go through indexing, parsing, and chunking processes. - - Args: - document_id: Unique identifier for the document - dataset_id: Parent dataset ID - indexing_status: Current status ('waiting', 'indexing', 'completed', 'error') - **kwargs: Additional attributes (is_paused, enabled, archived, etc.) - - Returns: - Mock: Document mock object - """ + """Create a document-shaped mock for DocumentService orchestration tests.""" document = Mock(spec=Document) document.id = document_id document.dataset_id = dataset_id document.indexing_status = indexing_status - for key, value in kwargs.items(): - setattr(document, key, value) + document.is_paused = is_paused + document.paused_by = None + document.paused_at = None return document -# ==================== Dataset Creation Tests ==================== - - -class TestDatasetServiceCreateDataset: - """ - Comprehensive unit tests for dataset creation logic. 
- - Covers: - - Internal dataset creation with various indexing techniques - - External dataset creation with external knowledge bindings - - RAG pipeline dataset creation - - Error handling for duplicate names and missing configurations - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """ - Common mock setup for dataset service dependencies. - - This fixture patches all external dependencies that DatasetService.create_empty_dataset - interacts with, including: - - db.session: Database operations (query, add, commit) - - ModelManager: Embedding model management - - check_embedding_model_setting: Validates embedding model configuration - - check_reranking_model_setting: Validates reranking model configuration - - ExternalDatasetService: Handles external knowledge API operations - - Yields: - dict: Dictionary of mocked dependencies for use in tests - """ - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, - patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, - patch("services.dataset_service.ExternalDatasetService") as mock_external_service, - ): - yield { - "db_session": mock_db, - "model_manager": mock_model_manager, - "check_embedding": mock_check_embedding, - "check_reranking": mock_check_reranking, - "external_service": mock_external_service, - } - - def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """ - Test successful creation of basic internal dataset. - - Verifies that a dataset can be created with minimal configuration: - - No indexing technique specified (None) - - Default permission (only_me) - - Vendor provider (internal dataset) - - This is the simplest dataset creation scenario. 
- """ - # Arrange: Set up test data and mocks - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Test Dataset" - description = "Test description" - - # Mock database query to return None (no duplicate name exists) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock database session operations for dataset creation - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() # Tracks dataset being added to session - mock_db.flush = Mock() # Flushes to get dataset ID - mock_db.commit = Mock() # Commits transaction - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=description, - indexing_technique=None, - account=account, - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == account.id - assert result.updated_by == account.id - assert result.provider == "vendor" - assert result.permission == "only_me" - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies): - """Test successful creation of internal dataset with economy indexing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Economy Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( 
- tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="economy", - account=account, - ) - - # Assert - assert result.indexing_technique == "economy" - assert result.embedding_model_provider is None - assert result.embedding_model is None - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing(self, mock_dataset_service_dependencies): - """Test creation with high_quality indexing using default embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "High Quality Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_model.provider - assert result.embedding_model == embedding_model.model_name - mock_model_manager_instance.get_default_model_instance.assert_called_once_with( - tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - mock_db.commit.assert_called_once() - - def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """Test error when creating dataset with duplicate 
name.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Duplicate Dataset" - - # Mock database query to return existing dataset - existing_dataset = DatasetServiceTestDataFactory.create_dataset_mock(name=name, tenant_id=tenant_id) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError) as context: - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - ) - - assert f"Dataset with name {name} already exists" in str(context.value) - - def test_create_external_dataset_success(self, mock_dataset_service_dependencies): - """Test successful creation of external dataset with external knowledge binding.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_knowledge_api_id = "api-123" - external_knowledge_id = "knowledge-123" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = Mock() - external_api.id = external_knowledge_api_id - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_knowledge_api_id, - 
external_knowledge_id=external_knowledge_id, - ) - - # Assert - assert result.provider == "external" - assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBinding - mock_db.commit.assert_called_once() - - -# ==================== Dataset Update Tests ==================== - - -class TestDatasetServiceUpdateDataset: - """ - Comprehensive unit tests for dataset update settings. - - Covers: - - Basic field updates (name, description, permission) - - Indexing technique changes (economy <-> high_quality) - - Embedding model updates - - Retrieval configuration updates - - External dataset updates - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_time, - patch( - "services.dataset_service.DatasetService._update_pipeline_knowledge_base_node_data" - ) as mock_update_pipeline, - ): - mock_time.return_value = "2024-01-01T00:00:00" - yield { - "get_dataset": mock_get_dataset, - "has_dataset_same_name": mock_has_same_name, - "check_permission": mock_check_perm, - "db_session": mock_db, - "current_time": "2024-01-01T00:00:00", - "update_pipeline": mock_update_pipeline, - } - - @pytest.fixture - def mock_internal_provider_dependencies(self): - """Mock dependencies for internal dataset provider operations.""" - with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetCollectionBindingService") as mock_binding_service, - patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, - 
patch("services.dataset_service.current_user") as mock_current_user, - ): - # Mock current_user as Account instance - mock_current_user_account = DatasetServiceTestDataFactory.create_account_mock( - account_id="user-123", tenant_id="tenant-123" - ) - mock_current_user.return_value = mock_current_user_account - mock_current_user.current_tenant_id = "tenant-123" - mock_current_user.id = "user-123" - # Make isinstance check pass - mock_current_user.__class__ = Account - - yield { - "model_manager": mock_model_manager, - "get_binding": mock_binding_service.get_dataset_collection_binding, - "task": mock_task, - "current_user": mock_current_user, - } - - @pytest.fixture - def mock_external_provider_dependencies(self): - """Mock dependencies for external dataset provider operations.""" - with ( - patch("services.dataset_service.Session") as mock_session, - patch("services.dataset_service.db.engine") as mock_engine, - ): - yield mock_session - - def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """Test successful update of internal dataset with basic fields.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - collection_binding_id="binding-123", - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetServiceTestDataFactory.create_account_mock() - - update_data = { - "name": "new_name", - "description": "new_description", - "indexing_technique": "high_quality", - "retrieval_model": "new_model", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - } - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - # Act - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Assert - 
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.assert_called_once() - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - assert result == dataset - - def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies): - """Test error when updating non-existent dataset.""" - # Arrange - mock_dataset_service_dependencies["get_dataset"].return_value = None - user = DatasetServiceTestDataFactory.create_account_mock() - - # Act & Assert - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("non-existent", {}, user) - - assert "Dataset not found" in str(context.value) - - def test_update_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """Test error when updating dataset to duplicate name.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock() - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = True - - user = DatasetServiceTestDataFactory.create_account_mock() - update_data = {"name": "duplicate_name"} - - # Act & Assert - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "Dataset name already exists" in str(context.value) - - def test_update_indexing_technique_to_economy( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test updating indexing technique from high_quality to economy.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock( - provider="vendor", indexing_technique="high_quality" - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetServiceTestDataFactory.create_account_mock() - - update_data = {"indexing_technique": "economy", 
"retrieval_model": "new_model"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - # Act - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Assert - mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.assert_called_once() - # Verify embedding model fields are cleared - call_args = mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.call_args[0][0] - assert call_args["embedding_model"] is None - assert call_args["embedding_model_provider"] is None - assert call_args["collection_binding_id"] is None - assert result == dataset - - def test_update_indexing_technique_to_high_quality( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test updating indexing technique from economy to high_quality.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetServiceTestDataFactory.create_account_mock() - - # Mock embedding model - embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock() - mock_internal_provider_dependencies[ - "model_manager" - ].return_value.get_model_instance.return_value = embedding_model - - # Mock collection binding - binding = DatasetServiceTestDataFactory.create_collection_binding_mock() - mock_internal_provider_dependencies["get_binding"].return_value = binding - - update_data = { - "indexing_technique": "high_quality", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - "retrieval_model": "new_model", - } - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - # Act - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Assert - 
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once() - mock_internal_provider_dependencies["get_binding"].assert_called_once() - mock_internal_provider_dependencies["task"].delay.assert_called_once() - call_args = mock_internal_provider_dependencies["task"].delay.call_args[0] - assert call_args[0] == "dataset-123" - assert call_args[1] == "add" - - # Verify return value - assert result == dataset - - # Note: External dataset update test removed due to Flask app context complexity in unit tests - # External dataset functionality is covered by integration tests - - def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge id is missing.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock(provider="external") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetServiceTestDataFactory.create_account_mock() - update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - # Act & Assert - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "External knowledge id is required" in str(context.value) - - -# ==================== Dataset Deletion Tests ==================== - - -class TestDatasetServiceDeleteDataset: - """ - Comprehensive unit tests for dataset deletion with cascade operations. 
- - Covers: - - Normal dataset deletion with documents - - Empty dataset deletion (no documents) - - Dataset deletion with partial None values - - Permission checks - - Event handling for cascade operations - - Dataset deletion is a critical operation that triggers cascade cleanup: - - Documents and segments are removed from vector database - - File storage is cleaned up - - Related bindings and metadata are deleted - - The dataset_was_deleted event notifies listeners for cleanup - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """ - Common mock setup for dataset deletion dependencies. - - Patches: - - get_dataset: Retrieves the dataset to delete - - check_dataset_permission: Verifies user has delete permission - - db.session: Database operations (delete, commit) - - dataset_was_deleted: Signal/event for cascade cleanup operations - - The dataset_was_deleted signal is crucial - it triggers cleanup handlers - that remove vector embeddings, files, and related data. 
- """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, - ): - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "db_session": mock_db, - "dataset_was_deleted": mock_dataset_was_deleted, - } - - def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): - """Test successful deletion of a dataset with documents.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" - ) - user = DatasetServiceTestDataFactory.create_account_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of an empty dataset (no documents, doc_form is None). - - Empty datasets are created but never had documents uploaded. They have: - - doc_form = None (no document format configured) - - indexing_technique = None (no indexing method set) - - This test ensures empty datasets can be deleted without errors. 
- The event handler should gracefully skip cleanup operations when - there's no actual data to clean up. - - This test provides regression protection for issue #27073 where - deleting empty datasets caused internal server errors. - """ - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) - user = DatasetServiceTestDataFactory.create_account_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - # Event is sent even for empty datasets - handlers check for None values - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): - """Test deletion attempt when dataset doesn't exist.""" - # Arrange - dataset_id = "non-existent-dataset" - user = DatasetServiceTestDataFactory.create_account_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = None - - # Act - result = DatasetService.delete_dataset(dataset_id, user) - - # Assert - assert result is False - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) - mock_dataset_service_dependencies["check_permission"].assert_not_called() - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() - mock_dataset_service_dependencies["db_session"].delete.assert_not_called() - mock_dataset_service_dependencies["db_session"].commit.assert_not_called() - - def 
test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): - """Test deletion of dataset with partial None values (doc_form exists but indexing_technique is None).""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) - user = DatasetServiceTestDataFactory.create_account_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - assert result is True - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - -# ==================== Document Indexing Logic Tests ==================== - - class TestDatasetServiceDocumentIndexing: - """ - Comprehensive unit tests for document indexing logic. - - Covers: - - Document indexing status transitions - - Pause/resume document indexing - - Retry document indexing - - Sync website document indexing - - Document indexing task triggering - - Document indexing is an async process with multiple stages: - 1. waiting: Document queued for processing - 2. parsing: Extracting text from file - 3. cleaning: Removing unwanted content - 4. splitting: Breaking into chunks - 5. indexing: Creating embeddings and storing in vector DB - 6. completed: Successfully indexed - 7. error: Failed at some stage - - Users can pause/resume indexing or retry failed documents. - """ + """Unit tests for pause/recover/retry orchestration without SQL assertions.""" @pytest.fixture def mock_document_service_dependencies(self): - """ - Common mock setup for document service dependencies. 
- - Patches: - - redis_client: Caches indexing state and prevents concurrent operations - - db.session: Database operations for document status updates - - current_user: User context for tracking who paused/resumed - - Redis is used to: - - Store pause flags (document_{id}_is_paused) - - Prevent duplicate retry operations (document_{id}_is_retried) - - Track active indexing operations (document_{id}_indexing) - """ + """Patch non-SQL collaborators used by DocumentService methods.""" with ( patch("services.dataset_service.redis_client") as mock_redis, patch("services.dataset_service.db.session") as mock_db, @@ -930,271 +53,77 @@ class TestDatasetServiceDocumentIndexing: } def test_pause_document_success(self, mock_document_service_dependencies): - """ - Test successful pause of document indexing. - - Pausing allows users to temporarily stop indexing without canceling it. - This is useful when: - - System resources are needed elsewhere - - User wants to modify document settings before continuing - - Indexing is taking too long and needs to be deferred - - When paused: - - is_paused flag is set to True - - paused_by and paused_at are recorded - - Redis flag prevents indexing worker from processing - - Document remains in current indexing stage - """ + """Pause a document that is currently in an indexable status.""" # Arrange - document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="indexing") - mock_db = mock_document_service_dependencies["db_session"] - mock_redis = mock_document_service_dependencies["redis_client"] + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") # Act from services.dataset_service import DocumentService DocumentService.pause_document(document) - # Assert - Verify pause state is persisted + # Assert assert document.is_paused is True - mock_db.add.assert_called_once_with(document) - mock_db.commit.assert_called_once() - # setnx (set if not exists) prevents race conditions - 
mock_redis.setnx.assert_called_once() + assert document.paused_by == "user-123" + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( + f"document_{document.id}_is_paused", + "True", + ) def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - """Test error when pausing document with invalid status.""" + """Raise DocumentIndexingError when pausing a completed document.""" # Arrange - document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="completed") + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - # Act & Assert + # Act / Assert from services.dataset_service import DocumentService - from services.errors.document import DocumentIndexingError with pytest.raises(DocumentIndexingError): DocumentService.pause_document(document) def test_recover_document_success(self, mock_document_service_dependencies): - """Test successful recovery of paused document indexing.""" + """Recover a paused document and dispatch the recover indexing task.""" # Arrange - document = DatasetServiceTestDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - mock_db = mock_document_service_dependencies["db_session"] - mock_redis = mock_document_service_dependencies["redis_client"] + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) # Act - with patch("services.dataset_service.recover_document_indexing_task") as mock_task: + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: from services.dataset_service import DocumentService DocumentService.recover_document(document) - # Assert - assert document.is_paused is False - mock_db.add.assert_called_once_with(document) - 
mock_db.commit.assert_called_once() - mock_redis.delete.assert_called_once() - mock_task.delay.assert_called_once_with(document.dataset_id, document.id) + # Assert + assert document.is_paused is False + assert document.paused_by is None + assert document.paused_at is None + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( + f"document_{document.id}_is_paused" + ) + recover_task.delay.assert_called_once_with(document.dataset_id, document.id) def test_retry_document_indexing_success(self, mock_document_service_dependencies): - """Test successful retry of document indexing.""" + """Reset documents to waiting state and dispatch retry indexing task.""" # Arrange dataset_id = "dataset-123" documents = [ - DatasetServiceTestDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceTestDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), ] - mock_db = mock_document_service_dependencies["db_session"] - mock_redis = mock_document_service_dependencies["redis_client"] - mock_redis.get.return_value = None + mock_document_service_dependencies["redis_client"].get.return_value = None # Act - with patch("services.dataset_service.retry_document_indexing_task") as mock_task: + with patch("services.dataset_service.retry_document_indexing_task") as retry_task: from services.dataset_service import DocumentService DocumentService.retry_document(dataset_id, documents) - # Assert - for doc in documents: - assert doc.indexing_status == "waiting" - assert mock_db.add.call_count == len(documents) - # Commit is called once per document 
- assert mock_db.commit.call_count == len(documents) - mock_task.delay.assert_called_once() - - -# ==================== Retrieval Configuration Tests ==================== - - -class TestDatasetServiceRetrievalConfiguration: - """ - Comprehensive unit tests for retrieval configuration. - - Covers: - - Retrieval model configuration - - Search method configuration - - Top-k and score threshold settings - - Reranking model configuration - - Retrieval configuration controls how documents are searched and ranked: - - Search Methods: - - semantic_search: Uses vector similarity (cosine distance) - - full_text_search: Uses keyword matching (BM25) - - hybrid_search: Combines both methods with weighted scores - - Parameters: - - top_k: Number of results to return (default: 2-10) - - score_threshold: Minimum similarity score (0.0-1.0) - - reranking_enable: Whether to use reranking model for better results - - Reranking: - After initial retrieval, a reranking model (e.g., Cohere rerank) can - reorder results for better relevance. This is more accurate but slower. - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """ - Common mock setup for retrieval configuration tests. 
- - Patches: - - get_dataset: Retrieves dataset with retrieval configuration - - db.session: Database operations for configuration updates - """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.db.session") as mock_db, - ): - yield { - "get_dataset": mock_get_dataset, - "db_session": mock_db, - } - - def test_get_dataset_retrieval_configuration(self, mock_dataset_service_dependencies): - """Test retrieving dataset with retrieval configuration.""" - # Arrange - dataset_id = "dataset-123" - retrieval_model_config = { - "search_method": "semantic_search", - "top_k": 5, - "score_threshold": 0.5, - "reranking_enable": True, - } - dataset = DatasetServiceTestDataFactory.create_dataset_mock( - dataset_id=dataset_id, retrieval_model=retrieval_model_config - ) - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.get_dataset(dataset_id) - # Assert - assert result is not None - assert result.retrieval_model == retrieval_model_config - assert result.retrieval_model["search_method"] == "semantic_search" - assert result.retrieval_model["top_k"] == 5 - assert result.retrieval_model["score_threshold"] == 0.5 - - def test_update_dataset_retrieval_configuration(self, mock_dataset_service_dependencies): - """Test updating dataset retrieval configuration.""" - # Arrange - dataset = DatasetServiceTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - retrieval_model={"search_method": "semantic_search", "top_k": 2}, - ) - - with ( - patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("services.dataset_service.naive_utc_now") as mock_time, - patch( - "services.dataset_service.DatasetService._update_pipeline_knowledge_base_node_data" - ) as mock_update_pipeline, - ): - 
mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_has_same_name.return_value = False - mock_time.return_value = "2024-01-01T00:00:00" - - user = DatasetServiceTestDataFactory.create_account_mock() - - new_retrieval_config = { - "search_method": "full_text_search", - "top_k": 10, - "score_threshold": 0.7, - } - - update_data = { - "indexing_technique": "high_quality", - "retrieval_model": new_retrieval_config, - } - - # Act - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Assert - mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.assert_called_once() - call_args = mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.call_args[0][0] - assert call_args["retrieval_model"] == new_retrieval_config - assert result == dataset - - def test_create_dataset_with_retrieval_model_and_reranking(self, mock_dataset_service_dependencies): - """Test creating dataset with retrieval model and reranking configuration.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetServiceTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Dataset with Reranking" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock retrieval model with reranking - retrieval_model = Mock(spec=RetrievalModel) - retrieval_model.model_dump.return_value = { - "search_method": "semantic_search", - "top_k": 3, - "score_threshold": 0.6, - "reranking_enable": True, - } - reranking_model = Mock() - reranking_model.reranking_provider_name = "cohere" - reranking_model.reranking_model_name = "rerank-english-v2.0" - retrieval_model.reranking_model = reranking_model - - # Mock model manager - embedding_model = DatasetServiceTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance 
= Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - - with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, - patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, - ): - mock_model_manager.return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - assert result.retrieval_model == retrieval_model.model_dump() - mock_check_reranking.assert_called_once_with(tenant_id, "cohere", "rerank-english-v2.0") - mock_db.commit.assert_called_once() + assert all(document.indexing_status == "waiting" for document in documents) + assert mock_document_service_dependencies["db_session"].add.call_count == 2 + assert mock_document_service_dependencies["db_session"].commit.call_count == 2 + assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 + retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py index babd620ab7..a7e1a011f6 100644 --- a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py @@ -28,10 +28,14 @@ class TestArchivedWorkflowRunDeletion: with ( patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), patch( - "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", 
return_value=session_maker + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", + return_value=session_maker, + autospec=True, ), - patch.object(deleter, "_get_workflow_run_repo", return_value=repo), - patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run, + patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True), + patch.object( + deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True + ) as mock_delete_run, ): result = deleter.delete_by_run_id("run-1") @@ -46,7 +50,7 @@ class TestArchivedWorkflowRunDeletion: run.id = "run-1" run.tenant_id = "tenant-1" - with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo: result = deleter._delete_run(run) assert result.success is True diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py index 3b619195c7..67ae2c9142 100644 --- a/api/tests/unit_tests/services/test_messages_clean_service.py +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -402,7 +402,7 @@ class TestBillingDisabledPolicyFilterMessageIds: class TestCreateMessageCleanPolicy: """Unit tests for create_message_clean_policy factory function.""" - @patch("services.retention.conversation.messages_clean_policy.dify_config") + @patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True) def test_billing_disabled_returns_billing_disabled_policy(self, mock_config): """Test that BILLING_ENABLED=False returns BillingDisabledPolicy.""" # Arrange @@ -414,8 +414,8 @@ class TestCreateMessageCleanPolicy: # Assert assert isinstance(policy, BillingDisabledPolicy) - @patch("services.retention.conversation.messages_clean_policy.BillingService") - @patch("services.retention.conversation.messages_clean_policy.dify_config") + 
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True) + @patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True) def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service): """Test that BillingSandboxPolicy is created with correct internal values.""" # Arrange @@ -554,7 +554,7 @@ class TestMessagesCleanServiceFromDays: MessagesCleanService.from_days(policy=policy, days=-1) # Act - with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) mock_datetime.datetime.now.return_value = fixed_now mock_datetime.timedelta = datetime.timedelta @@ -586,7 +586,7 @@ class TestMessagesCleanServiceFromDays: dry_run = True # Act - with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) mock_datetime.datetime.now.return_value = fixed_now mock_datetime.timedelta = datetime.timedelta @@ -613,7 +613,7 @@ class TestMessagesCleanServiceFromDays: policy = BillingDisabledPolicy() # Act - with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) mock_datetime.datetime.now.return_value = fixed_now mock_datetime.timedelta = datetime.timedelta diff --git a/api/tests/unit_tests/services/test_recommended_app_service.py b/api/tests/unit_tests/services/test_recommended_app_service.py index 8d6d271689..12f4c0b982 100644 --- a/api/tests/unit_tests/services/test_recommended_app_service.py 
+++ b/api/tests/unit_tests/services/test_recommended_app_service.py @@ -134,8 +134,8 @@ def factory(): class TestRecommendedAppServiceGetApps: """Test get_recommended_apps_and_categories operations.""" - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory): """Test successful retrieval of recommended apps when apps are returned.""" # Arrange @@ -161,8 +161,8 @@ class TestRecommendedAppServiceGetApps: mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory): """Test fallback to builtin when no recommended apps are returned.""" # Arrange @@ -199,8 +199,8 @@ class TestRecommendedAppServiceGetApps: # Verify fallback was called with en-US (hardcoded) mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, 
factory): """Test fallback when recommended_apps key is None.""" # Arrange @@ -232,8 +232,8 @@ class TestRecommendedAppServiceGetApps: assert result == builtin_response mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory): """Test retrieval with different language codes.""" # Arrange @@ -262,8 +262,8 @@ class TestRecommendedAppServiceGetApps: assert result["recommended_apps"][0]["id"] == f"app-{language}" mock_instance.get_recommended_apps_and_categories.assert_called_with(language) - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory): """Test that correct factory is selected based on mode.""" # Arrange @@ -292,8 +292,8 @@ class TestRecommendedAppServiceGetApps: class TestRecommendedAppServiceGetDetail: """Test get_recommend_app_detail operations.""" - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory): """Test successful retrieval of app detail.""" # Arrange @@ -324,8 +324,8 @@ class 
TestRecommendedAppServiceGetDetail: assert result["name"] == "Productivity App" mock_instance.get_recommend_app_detail.assert_called_once_with(app_id) - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory): """Test app detail retrieval with different factory modes.""" # Arrange @@ -352,8 +352,8 @@ class TestRecommendedAppServiceGetDetail: assert result["name"] == f"App from {mode}" mock_factory_class.get_recommend_app_factory.assert_called_with(mode) - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory): """Test that None is returned when app is not found.""" # Arrange @@ -375,8 +375,8 @@ class TestRecommendedAppServiceGetDetail: assert result is None mock_instance.get_recommend_app_detail.assert_called_once_with(app_id) - @patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory): """Test handling of empty dict response.""" # Arrange @@ -397,8 +397,8 @@ class TestRecommendedAppServiceGetDetail: # Assert assert result == {} - 
@patch("services.recommended_app_service.RecommendAppRetrievalFactory") - @patch("services.recommended_app_service.dify_config") + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory): """Test app detail with complex model configuration.""" # Arrange diff --git a/api/tests/unit_tests/services/test_saved_message_service.py b/api/tests/unit_tests/services/test_saved_message_service.py index 15e37a9008..87b946fe46 100644 --- a/api/tests/unit_tests/services/test_saved_message_service.py +++ b/api/tests/unit_tests/services/test_saved_message_service.py @@ -201,8 +201,8 @@ def factory(): class TestSavedMessageServicePagination: """Test saved message pagination operations.""" - @patch("services.saved_message_service.MessageService.pagination_by_last_id") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory): """Test pagination with an Account user.""" # Arrange @@ -247,8 +247,8 @@ class TestSavedMessageServicePagination: include_ids=["msg-0", "msg-1", "msg-2"], ) - @patch("services.saved_message_service.MessageService.pagination_by_last_id") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory): """Test pagination with an EndUser.""" # Arrange @@ -301,8 +301,8 @@ class TestSavedMessageServicePagination: with pytest.raises(ValueError, match="User is 
required"): SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20) - @patch("services.saved_message_service.MessageService.pagination_by_last_id") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory): """Test pagination with last_id parameter.""" # Arrange @@ -340,8 +340,8 @@ class TestSavedMessageServicePagination: call_args = mock_message_pagination.call_args assert call_args.kwargs["last_id"] == last_id - @patch("services.saved_message_service.MessageService.pagination_by_last_id") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory): """Test pagination when user has no saved messages.""" # Arrange @@ -377,8 +377,8 @@ class TestSavedMessageServicePagination: class TestSavedMessageServiceSave: """Test save message operations.""" - @patch("services.saved_message_service.MessageService.get_message") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.get_message", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_save_message_for_account(self, mock_db_session, mock_get_message, factory): """Test saving a message for an Account user.""" # Arrange @@ -407,8 +407,8 @@ class TestSavedMessageServiceSave: assert saved_message.created_by_role == "account" mock_db_session.commit.assert_called_once() - @patch("services.saved_message_service.MessageService.get_message") - @patch("services.saved_message_service.db.session") + 
@patch("services.saved_message_service.MessageService.get_message", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory): """Test saving a message for an EndUser.""" # Arrange @@ -437,7 +437,7 @@ class TestSavedMessageServiceSave: assert saved_message.created_by_role == "end_user" mock_db_session.commit.assert_called_once() - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_save_without_user_does_nothing(self, mock_db_session, factory): """Test that saving without user is a no-op.""" # Arrange @@ -451,8 +451,8 @@ class TestSavedMessageServiceSave: mock_db_session.add.assert_not_called() mock_db_session.commit.assert_not_called() - @patch("services.saved_message_service.MessageService.get_message") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.get_message", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory): """Test that saving an already saved message is idempotent.""" # Arrange @@ -480,8 +480,8 @@ class TestSavedMessageServiceSave: mock_db_session.commit.assert_not_called() mock_get_message.assert_not_called() - @patch("services.saved_message_service.MessageService.get_message") - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.MessageService.get_message", autospec=True) + @patch("services.saved_message_service.db.session", autospec=True) def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory): """Test that save validates message exists through MessageService.""" # Arrange @@ -508,7 +508,7 @@ class TestSavedMessageServiceSave: class TestSavedMessageServiceDelete: """Test delete saved message 
operations.""" - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_delete_saved_message_for_account(self, mock_db_session, factory): """Test deleting a saved message for an Account user.""" # Arrange @@ -535,7 +535,7 @@ class TestSavedMessageServiceDelete: mock_db_session.delete.assert_called_once_with(saved_message) mock_db_session.commit.assert_called_once() - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_delete_saved_message_for_end_user(self, mock_db_session, factory): """Test deleting a saved message for an EndUser.""" # Arrange @@ -562,7 +562,7 @@ class TestSavedMessageServiceDelete: mock_db_session.delete.assert_called_once_with(saved_message) mock_db_session.commit.assert_called_once() - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_delete_without_user_does_nothing(self, mock_db_session, factory): """Test that deleting without user is a no-op.""" # Arrange @@ -576,7 +576,7 @@ class TestSavedMessageServiceDelete: mock_db_session.delete.assert_not_called() mock_db_session.commit.assert_not_called() - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory): """Test that deleting a non-existent saved message is a no-op.""" # Arrange @@ -597,7 +597,7 @@ class TestSavedMessageServiceDelete: mock_db_session.delete.assert_not_called() mock_db_session.commit.assert_not_called() - @patch("services.saved_message_service.db.session") + @patch("services.saved_message_service.db.session", autospec=True) def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory): """Test that delete only removes the user's own saved message.""" # Arrange diff --git 
a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 9494c0b211..264eac4d77 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -315,7 +315,7 @@ class TestTagServiceRetrieval: - get_tags_by_target_id: Get all tags bound to a specific target """ - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tags_with_binding_counts(self, mock_db_session, factory): """ Test retrieving tags with their binding counts. @@ -372,7 +372,7 @@ class TestTagServiceRetrieval: # Verify database query was called mock_db_session.query.assert_called_once() - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tags_with_keyword_filter(self, mock_db_session, factory): """ Test retrieving tags filtered by keyword (case-insensitive). @@ -426,7 +426,7 @@ class TestTagServiceRetrieval: # 2. Additional WHERE clause for keyword filtering assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): """ Test retrieving target IDs by tag IDs. @@ -482,7 +482,7 @@ class TestTagServiceRetrieval: # Verify both queries were executed assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): """ Test that empty tag_ids returns empty list. 
@@ -510,7 +510,7 @@ class TestTagServiceRetrieval: assert results == [], "Should return empty list for empty input" mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tag_by_tag_name(self, mock_db_session, factory): """ Test retrieving tags by name. @@ -552,7 +552,7 @@ class TestTagServiceRetrieval: assert len(results) == 1, "Should find exactly one tag" assert results[0].name == tag_name, "Tag name should match" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): """ Test that missing tag_type or tag_name returns empty list. @@ -580,7 +580,7 @@ class TestTagServiceRetrieval: # Verify no database queries were executed mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tags_by_target_id(self, mock_db_session, factory): """ Test retrieving tags associated with a specific target. @@ -651,10 +651,10 @@ class TestTagServiceCRUD: - get_tag_binding_count: Get count of bindings for a tag """ - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.get_tag_by_tag_name") - @patch("services.tag_service.db.session") - @patch("services.tag_service.uuid.uuid4") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) + @patch("services.tag_service.db.session", autospec=True) + @patch("services.tag_service.uuid.uuid4", autospec=True) def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ Test creating a new tag. 
@@ -709,8 +709,8 @@ class TestTagServiceCRUD: assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.get_tag_by_tag_name") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory): """ Test that creating a tag with duplicate name raises ValueError. @@ -740,9 +740,9 @@ class TestTagServiceCRUD: with pytest.raises(ValueError, match="Tag name already exists"): TagService.save_tags(args) - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.get_tag_by_tag_name") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): """ Test updating a tag name. 
@@ -792,9 +792,9 @@ class TestTagServiceCRUD: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.get_tag_by_tag_name") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_update_tags_raises_error_for_duplicate_name( self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory ): @@ -826,7 +826,7 @@ class TestTagServiceCRUD: with pytest.raises(ValueError, match="Tag name already exists"): TagService.update_tags(args, tag_id="tag-123") - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): """ Test that updating a non-existent tag raises NotFound. @@ -848,8 +848,8 @@ class TestTagServiceCRUD: mock_query.first.return_value = None # Mock duplicate check and current_user - with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[]): - with patch("services.tag_service.current_user") as mock_user: + with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True): + with patch("services.tag_service.current_user", autospec=True) as mock_user: mock_user.current_tenant_id = "tenant-123" args = {"name": "New Name", "type": "app"} @@ -858,7 +858,7 @@ class TestTagServiceCRUD: with pytest.raises(NotFound, match="Tag not found"): TagService.update_tags(args, tag_id="nonexistent") - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_get_tag_binding_count(self, mock_db_session, factory): """ Test getting the count of bindings for a tag. 
@@ -894,7 +894,7 @@ class TestTagServiceCRUD: # Verify count matches expectation assert result == expected_count, "Binding count should match" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_delete_tag(self, mock_db_session, factory): """ Test deleting a tag and its bindings. @@ -950,7 +950,7 @@ class TestTagServiceCRUD: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.db.session") + @patch("services.tag_service.db.session", autospec=True) def test_delete_tag_raises_not_found(self, mock_db_session, factory): """ Test that deleting a non-existent tag raises NotFound. @@ -996,9 +996,9 @@ class TestTagServiceBindings: - check_target_exists: Validate target (dataset/app) existence """ - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.check_target_exists") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.check_target_exists", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test creating tag bindings. 
@@ -1047,9 +1047,9 @@ class TestTagServiceBindings: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.current_user") - @patch("services.tag_service.TagService.check_target_exists") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.TagService.check_target_exists", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): """ Test that saving duplicate bindings is idempotent. @@ -1088,8 +1088,8 @@ class TestTagServiceBindings: # Verify no new binding was added (idempotent) mock_db_session.add.assert_not_called(), "Should not create duplicate binding" - @patch("services.tag_service.TagService.check_target_exists") - @patch("services.tag_service.db.session") + @patch("services.tag_service.TagService.check_target_exists", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): """ Test deleting a tag binding. @@ -1136,8 +1136,8 @@ class TestTagServiceBindings: # Verify transaction was committed mock_db_session.commit.assert_called_once(), "Should commit transaction" - @patch("services.tag_service.TagService.check_target_exists") - @patch("services.tag_service.db.session") + @patch("services.tag_service.TagService.check_target_exists", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): """ Test that deleting a non-existent binding is a no-op. 
@@ -1173,8 +1173,8 @@ class TestTagServiceBindings: # Verify no commit was made (nothing changed) mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" - @patch("services.tag_service.current_user") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): """ Test validating that a dataset target exists. @@ -1214,8 +1214,8 @@ class TestTagServiceBindings: # Verify no exception was raised and query was executed mock_db_session.query.assert_called_once(), "Should query database for dataset" - @patch("services.tag_service.current_user") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): """ Test validating that an app target exists. 
@@ -1255,8 +1255,8 @@ class TestTagServiceBindings: # Verify no exception was raised and query was executed mock_db_session.query.assert_called_once(), "Should query database for app" - @patch("services.tag_service.current_user") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_check_target_exists_raises_not_found_for_missing_dataset( self, mock_db_session, mock_current_user, factory ): @@ -1287,8 +1287,8 @@ class TestTagServiceBindings: with pytest.raises(NotFound, match="Dataset not found"): TagService.check_target_exists("knowledge", "nonexistent") - @patch("services.tag_service.current_user") - @patch("services.tag_service.db.session") + @patch("services.tag_service.current_user", autospec=True) + @patch("services.tag_service.db.session", autospec=True) def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): """ Test that missing app raises NotFound. 
diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 4534e68b4e..8199d586da 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -17,7 +17,9 @@ from uuid import uuid4 import pytest -from core.variables.segments import ( +from core.workflow.file.enums import FileTransferMethod, FileType +from core.workflow.file.models import File +from core.workflow.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, @@ -28,8 +30,6 @@ from core.variables.segments import ( ObjectSegment, StringSegment, ) -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index d788657589..ffdcc046f9 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -87,7 +87,7 @@ class TestWebhookServiceUnit: webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" - with patch.object(WebhookService, "_process_file_uploads") as mock_process_files: + with patch.object(WebhookService, "_process_file_uploads", autospec=True) as mock_process_files: mock_process_files.return_value = {"file": "mocked_file_obj"} webhook_data = WebhookService.extract_webhook_data(webhook_trigger) @@ -123,8 +123,10 @@ class TestWebhookServiceUnit: mock_file.to_dict.return_value = {"file": "data"} with ( - patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect, - patch.object(WebhookService, "_create_file_from_binary") as mock_create, + patch.object( + WebhookService, "_detect_binary_mimetype", return_value="text/plain", autospec=True + ) as 
mock_detect, + patch.object(WebhookService, "_create_file_from_binary", autospec=True) as mock_create, ): mock_create.return_value = mock_file body, files = WebhookService._extract_octet_stream_body(webhook_trigger) @@ -168,7 +170,7 @@ class TestWebhookServiceUnit: fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error") monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) - with patch("services.trigger.webhook_service.logger") as mock_logger: + with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger: result = WebhookService._detect_binary_mimetype(b"binary data") assert result == "application/octet-stream" @@ -245,15 +247,12 @@ class TestWebhookServiceUnit: assert response_data[0]["id"] == 1 assert response_data[1]["id"] == 2 - @patch("services.trigger.webhook_service.ToolFileManager") - @patch("services.trigger.webhook_service.file_factory") + @patch("services.trigger.webhook_service.ToolFileManager", autospec=True) + @patch("services.trigger.webhook_service.file_factory", autospec=True) def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager): """Test successful file upload processing.""" # Mock ToolFileManager - mock_tool_file_instance = MagicMock() - mock_tool_file_manager.return_value = mock_tool_file_instance - - # Mock file creation + mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation mock_tool_file = MagicMock() mock_tool_file.id = "test_file_id" mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file @@ -285,15 +284,12 @@ class TestWebhookServiceUnit: assert mock_tool_file_manager.call_count == 2 assert mock_file_factory.build_from_mapping.call_count == 2 - @patch("services.trigger.webhook_service.ToolFileManager") - @patch("services.trigger.webhook_service.file_factory") + @patch("services.trigger.webhook_service.ToolFileManager", autospec=True) + @patch("services.trigger.webhook_service.file_factory", 
autospec=True) def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager): """Test file upload processing with errors.""" # Mock ToolFileManager - mock_tool_file_instance = MagicMock() - mock_tool_file_manager.return_value = mock_tool_file_instance - - # Mock file creation + mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation mock_tool_file = MagicMock() mock_tool_file.id = "test_file_id" mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file @@ -544,8 +540,8 @@ class TestWebhookServiceUnit: # Mock the WebhookService methods with ( - patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger, - patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract, + patch.object(WebhookService, "get_webhook_trigger_and_workflow", autospec=True) as mock_get_trigger, + patch.object(WebhookService, "extract_and_validate_webhook_data", autospec=True) as mock_extract, ): mock_trigger = MagicMock() mock_workflow = MagicMock() diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index ded141f01a..1f92ff590c 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -124,7 +124,7 @@ class TestWorkflowRunService: """Create WorkflowRunService instance with mocked dependencies.""" session_factory, _ = mock_session_factory - with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory: + with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory: mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository service = WorkflowRunService(session_factory) return service @@ -135,7 +135,7 @@ class TestWorkflowRunService: mock_engine = create_autospec(Engine) session_factory, _ = 
mock_session_factory - with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory: + with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory: mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository service = WorkflowRunService(mock_engine) return service @@ -146,7 +146,7 @@ class TestWorkflowRunService: """Test WorkflowRunService initialization with session_factory.""" session_factory, _ = mock_session_factory - with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory: + with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory: mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository service = WorkflowRunService(session_factory) @@ -158,9 +158,11 @@ class TestWorkflowRunService: mock_engine = create_autospec(Engine) session_factory, _ = mock_session_factory - with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory: + with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory: mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository - with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker: + with patch( + "services.workflow_run_service.sessionmaker", return_value=session_factory, autospec=True + ) as mock_sessionmaker: service = WorkflowRunService(mock_engine) mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False) diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 6e03472b9d..83642fc209 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -6,8 +6,8 @@ from 
unittest.mock import Mock, patch import pytest from sqlalchemy import Engine -from core.variables.segments import ObjectSegment, StringSegment -from core.variables.types import SegmentType +from core.workflow.variables.segments import ObjectSegment, StringSegment +from core.workflow.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -174,7 +174,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.variables.segments import FloatSegment + from core.workflow.variables.segments import FloatSegment mock_segment = FloatSegment(value=test_number) mock_build_segment.return_value = mock_segment @@ -224,7 +224,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.variables.segments import ArrayAnySegment + from core.workflow.variables.segments import ArrayAnySegment mock_segment = ArrayAnySegment(value=test_array) mock_build_segment.return_value = mock_segment diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 66361f26e0..792257848f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -7,10 +7,10 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from core.variables.segments import StringSegment -from core.variables.types import SegmentType from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.enums import NodeType 
+from core.workflow.variables.segments import StringSegment +from core.workflow.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType @@ -141,7 +141,7 @@ class TestDraftVariableSaver: def test_draft_saver_with_small_variables(self, draft_saver, mock_session): with patch( - "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable" + "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True ) as _mock_try_offload: _mock_try_offload.return_value = None mock_segment = StringSegment(value="small value") @@ -153,7 +153,7 @@ class TestDraftVariableSaver: def test_draft_saver_with_large_variables(self, draft_saver, mock_session): with patch( - "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable" + "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True ) as _mock_try_offload: mock_segment = StringSegment(value="small value") mock_draft_var_file = WorkflowDraftVariableFile( @@ -170,7 +170,7 @@ class TestDraftVariableSaver: # Should not have large variable metadata assert draft_var.file_id == mock_draft_var_file.id - @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable") + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True) def test_save_method_integration(self, mock_batch_upsert, draft_saver): """Test complete save workflow.""" outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}} @@ -222,7 +222,7 @@ class TestWorkflowDraftVariableService: name="test_var", value=StringSegment(value="reset_value"), ) - with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv: + with patch.object(service, "_reset_conv_var", return_value=expected_result, autospec=True) as mock_reset_conv: result = 
service.reset_variable(workflow, variable) mock_reset_conv.assert_called_once_with(workflow, variable) @@ -330,8 +330,8 @@ class TestWorkflowDraftVariableService: # Mock workflow methods mock_node_config = {"type": "test_node"} with ( - patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config), - patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM), + patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True), + patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True), ): result = service._reset_node_var_or_sys_var(workflow, variable) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index c96c8cf09d..df33f20c9b 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -50,7 +50,7 @@ def pipeline_id(): @pytest.fixture def mock_db_session(): """Mock database session via session_factory.create_session().""" - with patch("tasks.clean_dataset_task.session_factory") as mock_sf: + with patch("tasks.clean_dataset_task.session_factory", autospec=True) as mock_sf: mock_session = MagicMock() # context manager for create_session() cm = MagicMock() @@ -79,7 +79,7 @@ def mock_db_session(): @pytest.fixture def mock_storage(): """Mock storage client.""" - with patch("tasks.clean_dataset_task.storage") as mock_storage: + with patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage: mock_storage.delete.return_value = None yield mock_storage @@ -87,7 +87,7 @@ def mock_storage(): @pytest.fixture def mock_index_processor_factory(): """Mock IndexProcessorFactory.""" - with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory: + with patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_factory: mock_processor = MagicMock() mock_processor.clean.return_value 
= None mock_factory_instance = MagicMock() @@ -104,7 +104,7 @@ def mock_index_processor_factory(): @pytest.fixture def mock_get_image_upload_file_ids(): """Mock get_image_upload_file_ids function.""" - with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func: + with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_func: mock_func.return_value = [] yield mock_func diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 549f2c6c9b..a68aad7606 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -1,12 +1,8 @@ """ -Unit tests for document indexing sync task. +Unit tests for collaborator parameter wiring in document_indexing_sync_task. -This module tests the document indexing sync task functionality including: -- Syncing Notion documents when updated -- Validating document and data source existence -- Credential validation and retrieval -- Cleaning old segments before re-indexing -- Error handling and edge cases +These tests intentionally stay in unit scope because they validate call arguments +for external collaborators rather than SQL-backed state transitions. 
""" import uuid @@ -14,188 +10,93 @@ from unittest.mock import MagicMock, Mock, patch import pytest -from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document from tasks.document_indexing_sync_task import document_indexing_sync_task -# ============================================================================ -# Fixtures -# ============================================================================ - @pytest.fixture -def tenant_id(): - """Generate a unique tenant ID for testing.""" +def dataset_id() -> str: + """Generate a dataset id.""" return str(uuid.uuid4()) @pytest.fixture -def dataset_id(): - """Generate a unique dataset ID for testing.""" +def document_id() -> str: + """Generate a document id.""" return str(uuid.uuid4()) @pytest.fixture -def document_id(): - """Generate a unique document ID for testing.""" +def notion_workspace_id() -> str: + """Generate a notion workspace id.""" return str(uuid.uuid4()) @pytest.fixture -def notion_workspace_id(): - """Generate a Notion workspace ID for testing.""" +def notion_page_id() -> str: + """Generate a notion page id.""" return str(uuid.uuid4()) @pytest.fixture -def notion_page_id(): - """Generate a Notion page ID for testing.""" +def credential_id() -> str: + """Generate a credential id.""" return str(uuid.uuid4()) @pytest.fixture -def credential_id(): - """Generate a credential ID for testing.""" - return str(uuid.uuid4()) - - -@pytest.fixture -def mock_dataset(dataset_id, tenant_id): - """Create a mock Dataset object.""" +def mock_dataset(dataset_id): + """Create a minimal dataset mock used by the task pre-check.""" dataset = Mock(spec=Dataset) dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" - dataset.embedding_model_provider = "openai" - dataset.embedding_model = "text-embedding-ada-002" return dataset @pytest.fixture -def 
mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id): - """Create a mock Document object with Notion data source.""" - doc = Mock(spec=Document) - doc.id = document_id - doc.dataset_id = dataset_id - doc.tenant_id = tenant_id - doc.data_source_type = "notion_import" - doc.indexing_status = "completed" - doc.error = None - doc.stopped_at = None - doc.processing_started_at = None - doc.doc_form = "text_model" - doc.data_source_info_dict = { +def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, credential_id): + """Create a minimal notion document mock for collaborator parameter assertions.""" + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = str(uuid.uuid4()) + document.data_source_type = "notion_import" + document.indexing_status = "completed" + document.doc_form = "text_model" + document.data_source_info_dict = { "notion_workspace_id": notion_workspace_id, "notion_page_id": notion_page_id, "type": "page", "last_edited_time": "2024-01-01T00:00:00Z", "credential_id": credential_id, } - return doc + return document @pytest.fixture -def mock_document_segments(document_id): - """Create mock DocumentSegment objects.""" - segments = [] - for i in range(3): - segment = Mock(spec=DocumentSegment) - segment.id = str(uuid.uuid4()) - segment.document_id = document_id - segment.index_node_id = f"node-{document_id}-{i}" - segments.append(segment) - return segments +def mock_db_session(mock_document, mock_dataset): + """Mock session_factory.create_session to drive deterministic read-only task flow.""" + with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory: + session = MagicMock() + session.scalars.return_value.all.return_value = [] + session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + begin_cm = MagicMock() + begin_cm.__enter__.return_value = 
session + begin_cm.__exit__.return_value = False + session.begin.return_value = begin_cm -@pytest.fixture -def mock_db_session(): - """Mock database session via session_factory.create_session(). + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False - After session split refactor, the code calls create_session() multiple times. - This fixture creates shared query mocks so all sessions use the same - query configuration, simulating database persistence across sessions. - - The fixture automatically converts side_effect to cycle to prevent StopIteration. - Tests configure mocks the same way as before, but behind the scenes the values - are cycled infinitely for all sessions. - """ - from itertools import cycle - - with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf: - sessions = [] - - # Shared query mocks - all sessions use these - shared_query = MagicMock() - shared_filter_by = MagicMock() - shared_scalars_result = MagicMock() - - # Create custom first mock that auto-cycles side_effect - class CyclicMock(MagicMock): - def __setattr__(self, name, value): - if name == "side_effect" and value is not None: - # Convert list/tuple to infinite cycle - if isinstance(value, (list, tuple)): - value = cycle(value) - super().__setattr__(name, value) - - shared_query.where.return_value.first = CyclicMock() - shared_filter_by.first = CyclicMock() - - def _create_session(): - """Create a new mock session for each create_session() call.""" - session = MagicMock() - session.close = MagicMock() - session.commit = MagicMock() - - # Mock session.begin() context manager - begin_cm = MagicMock() - begin_cm.__enter__.return_value = session - - def _begin_exit_side_effect(exc_type, exc, tb): - # commit on success - if exc_type is None: - session.commit() - # return False to propagate exceptions - return False - - begin_cm.__exit__.side_effect = _begin_exit_side_effect - session.begin.return_value = 
begin_cm - - # Mock create_session() context manager - cm = MagicMock() - cm.__enter__.return_value = session - - def _exit_side_effect(exc_type, exc, tb): - session.close() - return False - - cm.__exit__.side_effect = _exit_side_effect - - # All sessions use the same shared query mocks - session.query.return_value = shared_query - shared_query.where.return_value = shared_query - shared_query.filter_by.return_value = shared_filter_by - session.scalars.return_value = shared_scalars_result - - sessions.append(session) - # Attach helpers on the first created session for assertions across all sessions - if len(sessions) == 1: - session.get_all_sessions = lambda: sessions - session.any_close_called = lambda: any(s.close.called for s in sessions) - session.any_commit_called = lambda: any(s.commit.called for s in sessions) - return cm - - mock_sf.create_session.side_effect = _create_session - - # Create first session and return it - _create_session() - yield sessions[0] + mock_session_factory.create_session.return_value = session_cm + yield session @pytest.fixture def mock_datasource_provider_service(): - """Mock DatasourceProviderService.""" - with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class: + """Mock datasource credential provider.""" + with patch("tasks.document_indexing_sync_task.DatasourceProviderService", autospec=True) as mock_service_class: mock_service = MagicMock() mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"} mock_service_class.return_value = mock_service @@ -204,314 +105,16 @@ def mock_datasource_provider_service(): @pytest.fixture def mock_notion_extractor(): - """Mock NotionExtractor.""" - with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: + """Mock notion extractor class and instance.""" + with patch("tasks.document_indexing_sync_task.NotionExtractor", autospec=True) as mock_extractor_class: mock_extractor = MagicMock() - 
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time + mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" mock_extractor_class.return_value = mock_extractor - yield mock_extractor + yield {"class": mock_extractor_class, "instance": mock_extractor} -@pytest.fixture -def mock_index_processor_factory(): - """Mock IndexProcessorFactory.""" - with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory: - mock_processor = MagicMock() - mock_processor.clean = Mock() - mock_factory.return_value.init_index_processor.return_value = mock_processor - yield mock_factory - - -@pytest.fixture -def mock_indexing_runner(): - """Mock IndexingRunner.""" - with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class: - mock_runner = MagicMock(spec=IndexingRunner) - mock_runner.run = Mock() - mock_runner_class.return_value = mock_runner - yield mock_runner - - -# ============================================================================ -# Tests for document_indexing_sync_task -# ============================================================================ - - -class TestDocumentIndexingSyncTask: - """Tests for the document_indexing_sync_task function.""" - - def test_document_not_found(self, mock_db_session, dataset_id, document_id): - """Test that task handles document not found gracefully.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = None - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - at least one session should have been closed - assert mock_db_session.any_close_called() - - def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): - """Test that task raises error when notion_workspace_id is missing.""" - # Arrange - mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"} - 
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - - # Act & Assert - with pytest.raises(ValueError, match="no notion page found"): - document_indexing_sync_task(dataset_id, document_id) - - def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id): - """Test that task raises error when notion_page_id is missing.""" - # Arrange - mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"} - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - - # Act & Assert - with pytest.raises(ValueError, match="no notion page found"): - document_indexing_sync_task(dataset_id, document_id) - - def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id): - """Test that task raises error when data_source_info is empty.""" - # Arrange - mock_document.data_source_info_dict = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - - # Act & Assert - with pytest.raises(ValueError, match="no notion page found"): - document_indexing_sync_task(dataset_id, document_id) - - def test_credential_not_found( - self, - mock_db_session, - mock_datasource_provider_service, - mock_document, - dataset_id, - document_id, - ): - """Test that task handles missing credentials by updating document status.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_datasource_provider_service.get_datasource_credentials.return_value = None - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - assert mock_document.indexing_status == "error" - assert "Datasource credential not found" in mock_document.error - assert mock_document.stopped_at is not None - assert mock_db_session.any_commit_called() - assert 
mock_db_session.any_close_called() - - def test_page_not_updated( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_document, - dataset_id, - document_id, - ): - """Test that task does nothing when page has not been updated.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - # Return same time as stored in document - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Document status should remain unchanged - assert mock_document.indexing_status == "completed" - # At least one session should have been closed via context manager teardown - assert mock_db_session.any_close_called() - - def test_successful_sync_when_page_updated( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_index_processor_factory, - mock_indexing_runner, - mock_dataset, - mock_document, - mock_document_segments, - dataset_id, - document_id, - ): - """Test successful sync flow when Notion page has been updated.""" - # Arrange - # Set exact sequence of returns across calls to `.first()`: - # 1) document (initial fetch) - # 2) dataset (pre-check) - # 3) dataset (cleaning phase) - # 4) document (pre-indexing update) - # 5) document (indexing runner fetch) - mock_db_session.query.return_value.where.return_value.first.side_effect = [ - mock_document, - mock_dataset, - mock_dataset, - mock_document, - mock_document, - ] - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_db_session.scalars.return_value.all.return_value = mock_document_segments - # NotionExtractor returns updated time - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" - - # Act - 
document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Verify document status was updated to parsing - assert mock_document.indexing_status == "parsing" - assert mock_document.processing_started_at is not None - - # Verify segments were cleaned - mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_processor.clean.assert_called_once() - - # Verify segments were deleted from database in batch (DELETE FROM document_segments) - # Aggregate execute calls across all created sessions - execute_sqls = [] - for s in mock_db_session.get_all_sessions(): - execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list]) - assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) - - # Verify indexing runner was called - mock_indexing_runner.run.assert_called_once_with([mock_document]) - - # Verify session operations (across any created session) - assert mock_db_session.any_commit_called() - assert mock_db_session.any_close_called() - - def test_dataset_not_found_during_cleaning( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_indexing_runner, - mock_document, - dataset_id, - document_id, - ): - """Test that task handles dataset not found during cleaning phase.""" - # Arrange - # Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing) - mock_db_session.query.return_value.where.return_value.first.side_effect = [ - mock_document, - mock_dataset, - None, - mock_document, - mock_document, - ] - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Document should still be set to parsing - assert mock_document.indexing_status == "parsing" - # At least one session should be 
closed after error - assert mock_db_session.any_close_called() - - def test_cleaning_error_continues_to_indexing( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_index_processor_factory, - mock_indexing_runner, - mock_dataset, - mock_document, - dataset_id, - document_id, - ): - """Test that indexing continues even if cleaning fails.""" - # Arrange - from itertools import cycle - - mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - # Make the cleaning step fail but not the segment fetch - processor = mock_index_processor_factory.return_value.init_index_processor.return_value - processor.clean.side_effect = Exception("Cleaning error") - mock_db_session.scalars.return_value.all.return_value = [] - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Indexing should still be attempted despite cleaning error - mock_indexing_runner.run.assert_called_once_with([mock_document]) - assert mock_db_session.any_close_called() - - def test_indexing_runner_document_paused_error( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_index_processor_factory, - mock_indexing_runner, - mock_dataset, - mock_document, - mock_document_segments, - dataset_id, - document_id, - ): - """Test that DocumentIsPausedError is handled gracefully.""" - # Arrange - from itertools import cycle - - mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_db_session.scalars.return_value.all.return_value = mock_document_segments - mock_notion_extractor.get_notion_last_edited_time.return_value 
= "2024-01-02T00:00:00Z" - mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Session should be closed after handling error - assert mock_db_session.any_close_called() - - def test_indexing_runner_general_error( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_index_processor_factory, - mock_indexing_runner, - mock_dataset, - mock_document, - mock_document_segments, - dataset_id, - document_id, - ): - """Test that general exceptions during indexing are handled.""" - # Arrange - from itertools import cycle - - mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document - mock_db_session.scalars.return_value.all.return_value = mock_document_segments - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" - mock_indexing_runner.run.side_effect = Exception("Indexing error") - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - # Session should be closed after error - assert mock_db_session.any_close_called() +class TestDocumentIndexingSyncTaskCollaboratorParams: + """Unit tests for collaborator parameter passing in document_indexing_sync_task.""" def test_notion_extractor_initialized_with_correct_params( self, @@ -524,27 +127,21 @@ class TestDocumentIndexingSyncTask: notion_workspace_id, notion_page_id, ): - """Test that NotionExtractor is initialized with correct parameters.""" + """Test that NotionExtractor is initialized with expected arguments.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update + expected_token = "test_token" # Act - with 
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class: - mock_extractor = MagicMock() - mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" - mock_extractor_class.return_value = mock_extractor + document_indexing_sync_task(dataset_id, document_id) - document_indexing_sync_task(dataset_id, document_id) - - # Assert - mock_extractor_class.assert_called_once_with( - notion_workspace_id=notion_workspace_id, - notion_obj_id=notion_page_id, - notion_page_type="page", - notion_access_token="test_token", - tenant_id=mock_document.tenant_id, - ) + # Assert + mock_notion_extractor["class"].assert_called_once_with( + notion_workspace_id=notion_workspace_id, + notion_obj_id=notion_page_id, + notion_page_type="page", + notion_access_token=expected_token, + tenant_id=mock_document.tenant_id, + ) def test_datasource_credentials_requested_correctly( self, @@ -556,17 +153,16 @@ class TestDocumentIndexingSyncTask: document_id, credential_id, ): - """Test that datasource credentials are requested with correct parameters.""" + """Test that datasource credentials are requested with expected identifiers.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" + expected_tenant_id = mock_document.tenant_id # Act document_indexing_sync_task(dataset_id, document_id) # Assert mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with( - tenant_id=mock_document.tenant_id, + tenant_id=expected_tenant_id, credential_id=credential_id, provider="notion_datasource", plugin_id="langgenius/notion_datasource", @@ -581,16 +177,14 @@ class TestDocumentIndexingSyncTask: dataset_id, document_id, ): - """Test that task handles missing credential_id by passing None.""" + """Test that missing credential_id is forwarded as None.""" # Arrange mock_document.data_source_info_dict = { - 
"notion_workspace_id": "ws123", - "notion_page_id": "page123", + "notion_workspace_id": "workspace-id", + "notion_page_id": "page-id", "type": "page", "last_edited_time": "2024-01-01T00:00:00Z", } - mock_db_session.query.return_value.where.return_value.first.return_value = mock_document - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # Act document_indexing_sync_task(dataset_id, document_id) @@ -602,39 +196,3 @@ class TestDocumentIndexingSyncTask: provider="notion_datasource", plugin_id="langgenius/notion_datasource", ) - - def test_index_processor_clean_called_with_correct_params( - self, - mock_db_session, - mock_datasource_provider_service, - mock_notion_extractor, - mock_index_processor_factory, - mock_indexing_runner, - mock_dataset, - mock_document, - mock_document_segments, - dataset_id, - document_id, - ): - """Test that index processor clean is called with correct parameters.""" - # Arrange - # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing) - mock_db_session.query.return_value.where.return_value.first.side_effect = [ - mock_document, - mock_dataset, - mock_dataset, - mock_document, - mock_document, - ] - mock_db_session.scalars.return_value.all.return_value = mock_document_segments - mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" - - # Act - document_indexing_sync_task(dataset_id, document_id) - - # Assert - mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - expected_node_ids = [seg.index_node_id for seg in mock_document_segments] - mock_processor.clean.assert_called_once_with( - mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True - ) diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index 8a4c6da2e9..68fb8b748f 100644 --- 
a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -95,7 +95,7 @@ def mock_document_segments(document_ids): @pytest.fixture def mock_db_session(): """Mock database session via session_factory.create_session().""" - with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf: + with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf: session = MagicMock() # Allow tests to observe session.close() via context manager teardown session.close = MagicMock() @@ -118,7 +118,7 @@ def mock_db_session(): @pytest.fixture def mock_indexing_runner(): """Mock IndexingRunner.""" - with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class: + with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class: mock_runner = MagicMock(spec=IndexingRunner) mock_runner_class.return_value = mock_runner yield mock_runner @@ -127,7 +127,7 @@ def mock_indexing_runner(): @pytest.fixture def mock_feature_service(): """Mock FeatureService.""" - with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service: + with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service: mock_features = Mock() mock_features.billing = Mock() mock_features.billing.enabled = False @@ -141,7 +141,7 @@ def mock_feature_service(): @pytest.fixture def mock_index_processor_factory(): """Mock IndexProcessorFactory.""" - with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory: + with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory: mock_processor = MagicMock() mock_processor.clean = Mock() mock_factory.return_value.init_index_processor.return_value = mock_processor @@ -151,7 +151,7 @@ def mock_index_processor_factory(): @pytest.fixture def 
mock_tenant_isolated_queue(): """Mock TenantIsolatedTaskQueue.""" - with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class: + with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class: mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) mock_queue.pull_tasks.return_value = [] mock_queue.delete_task_key = Mock() @@ -168,7 +168,7 @@ def mock_tenant_isolated_queue(): class TestDuplicateDocumentIndexingTask: """Tests for the deprecated duplicate_document_indexing_task function.""" - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids): """Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function.""" # Act @@ -177,7 +177,7 @@ class TestDuplicateDocumentIndexingTask: # Assert mock_core_func.assert_called_once_with(dataset_id, document_ids) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id): """Test duplicate_document_indexing_task with empty document_ids list.""" # Arrange @@ -445,7 +445,7 @@ class TestDuplicateDocumentIndexingTaskCore: class TestDuplicateDocumentIndexingTaskWithTenantQueue: """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_tenant_queue_wrapper_calls_core_function( self, mock_core_func, @@ -464,7 +464,7 @@ class 
TestDuplicateDocumentIndexingTaskWithTenantQueue: # Assert mock_core_func.assert_called_once_with(dataset_id, document_ids) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_tenant_queue_wrapper_deletes_key_when_no_tasks( self, mock_core_func, @@ -484,7 +484,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue: # Assert mock_tenant_isolated_queue.delete_task_key.assert_called_once() - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( self, mock_core_func, @@ -514,7 +514,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue: document_ids=document_ids, ) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True) def test_tenant_queue_wrapper_handles_core_function_error( self, mock_core_func, @@ -544,7 +544,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue: class TestNormalDuplicateDocumentIndexingTask: """Tests for normal_duplicate_document_indexing_task function.""" - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True) def test_normal_task_calls_tenant_queue_wrapper( self, mock_wrapper_func, @@ -561,7 +561,7 @@ class TestNormalDuplicateDocumentIndexingTask: tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task ) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + 
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True) def test_normal_task_with_empty_document_ids( self, mock_wrapper_func, @@ -589,7 +589,7 @@ class TestNormalDuplicateDocumentIndexingTask: class TestPriorityDuplicateDocumentIndexingTask: """Tests for priority_duplicate_document_indexing_task function.""" - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True) def test_priority_task_calls_tenant_queue_wrapper( self, mock_wrapper_func, @@ -606,7 +606,7 @@ class TestPriorityDuplicateDocumentIndexingTask: tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task ) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True) def test_priority_task_with_single_document( self, mock_wrapper_func, @@ -625,7 +625,7 @@ class TestPriorityDuplicateDocumentIndexingTask: tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task ) - @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue") + @patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True) def test_priority_task_with_large_batch( self, mock_wrapper_func, diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 9046f785d2..9a0dbfa2d8 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -321,7 +321,9 @@ def 
test_structured_output_parser(): ) else: # Test successful cases - with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: + with patch( + "core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True + ) as mock_json_repair: # Configure json_repair mock for cases that need it if case["name"] == "json_repair_scenario": mock_json_repair.return_value = {"name": "test"} @@ -402,7 +404,9 @@ def test_parse_structured_output_edge_cases(): prompt_messages = [UserPromptMessage(content="Test reasoning")] - with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair: + with patch( + "core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True + ) as mock_json_repair: # Mock json_repair to return a list with dict mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"] diff --git a/api/uv.lock b/api/uv.lock index e1e2ac8651..79886ca9a7 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -5049,11 +5049,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.7.1" +version = "6.7.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ff/63/3437c4363483f2a04000a48f1cd48c40097f69d580363712fa8b0b4afe45/pypdf-6.7.1.tar.gz", hash = "sha256:6b7a63be5563a0a35d54c6d6b550d75c00b8ccf36384be96365355e296e6b3b0", size = 5302208, upload-time = "2026-02-17T17:00:48.88Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/77/38bd7744bb9e06d465b0c23879e6d2c187d93a383f8fa485c862822bb8a3/pypdf-6.7.1-py3-none-any.whl", hash = 
"sha256:a02ccbb06463f7c334ce1612e91b3e68a8e827f3cee100b9941771e6066b094e", size = 331048, upload-time = "2026-02-17T17:00:46.991Z" }, + { url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" }, ] [[package]] diff --git a/web/README.md b/web/README.md index 64039709dc..f069ec82b2 100644 --- a/web/README.md +++ b/web/README.md @@ -33,7 +33,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t cp .env.example .env.local ``` -``` +```txt # For production release, change this to PRODUCTION NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT # The deployment edition, SELF_HOSTED diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 1c5434924f..4f3f724e62 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -28,13 +28,13 @@ import { cn } from '@/utils/classnames' export type IAppDetailLayoutProps = { children: React.ReactNode - params: { datasetId: string } + datasetId: string } const DatasetDetailLayout: FC = (props) => { const { children, - params: { datasetId }, + datasetId, } = props const { t } = useTranslation() const pathname = usePathname() diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index a8772f7cfd..64f3df1669 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -6,12 +6,11 @@ const DatasetDetailLayout = async ( params: 
Promise<{ datasetId: string }> }, ) => { - const params = await props.params - const { children, + params, } = props - return
{children}
+ return
{children}
} export default DatasetDetailLayout diff --git a/web/app/account/oauth/authorize/constants.ts b/web/app/account/oauth/authorize/constants.ts deleted file mode 100644 index f1d8b98ef4..0000000000 --- a/web/app/account/oauth/authorize/constants.ts +++ /dev/null @@ -1,3 +0,0 @@ -export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending' -export const REDIRECT_URL_KEY = 'oauth_redirect_url' -export const OAUTH_AUTHORIZE_PENDING_TTL = 60 * 3 diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index c923d6457a..d718e0941d 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -7,7 +7,6 @@ import { RiMailLine, RiTranslate2, } from '@remixicon/react' -import dayjs from 'dayjs' import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' @@ -17,22 +16,10 @@ import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' import { useAppContext } from '@/context/app-context' import { useIsLogin } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' -import { - OAUTH_AUTHORIZE_PENDING_KEY, - OAUTH_AUTHORIZE_PENDING_TTL, - REDIRECT_URL_KEY, -} from './constants' - -function setItemWithExpiry(key: string, value: string, ttl: number) { - const item = { - value, - expiry: dayjs().add(ttl, 'seconds').unix(), - } - localStorage.setItem(key, JSON.stringify(item)) -} function buildReturnUrl(pathname: string, search: string) { try { @@ -86,8 +73,8 @@ export default function OAuthAuthorize() { const onLoginSwitchClick = () => { try { const returnUrl = buildReturnUrl('/account/oauth/authorize', 
`?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) - setItemWithExpiry(OAUTH_AUTHORIZE_PENDING_KEY, returnUrl, OAUTH_AUTHORIZE_PENDING_TTL) - router.push(`/signin?${REDIRECT_URL_KEY}=${encodeURIComponent(returnUrl)}`) + setPostLoginRedirect(returnUrl) + router.push('/signin') } catch { router.push('/signin') @@ -145,7 +132,7 @@ export default function OAuthAuthorize() {
{authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })}
{!isLoggedIn &&
{t('tips.notLoggedIn', { ns: 'oauth' })}
} -
{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}
+
{isLoggedIn ? `${authAppInfo?.app_label[language] || authAppInfo?.app_label?.en_US || t('unknownApp', { ns: 'oauth' })} ${t('tips.loggedIn', { ns: 'oauth' })}` : t('tips.needLogin', { ns: 'oauth' })}
{isLoggedIn && userProfile && ( @@ -154,7 +141,7 @@ export default function OAuthAuthorize() {
{userProfile.name}
-
{userProfile.email}
+
{userProfile.email}
@@ -166,7 +153,7 @@ export default function OAuthAuthorize() { {authAppInfo!.scope.split(/\s+/).filter(Boolean).map((scope: string) => { const Icon = SCOPE_INFO_MAP[scope] return ( -
+
{Icon ? : } {Icon.label}
@@ -199,7 +186,7 @@ export default function OAuthAuthorize() {
-
{t('tips.common', { ns: 'oauth' })}
+
{t('tips.common', { ns: 'oauth' })}
) } diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index dfbac5d743..e4cd10175a 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -84,7 +84,7 @@ export const AppInitializer = ({ return } - const redirectUrl = resolvePostLoginRedirect(searchParams) + const redirectUrl = resolvePostLoginRedirect() if (redirectUrl) { location.replace(redirectUrl) return diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 5ad16d139f..692ae12022 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -61,8 +61,7 @@ const ParamsConfig = ({ if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { if (tempDataSetConfigs.reranking_enable && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel - && !isCurrentRerankModelValid - ) { + && !isCurrentRerankModelValid) { errMsg = t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) } } diff --git a/web/app/components/base/chat/chat/question.spec.tsx b/web/app/components/base/chat/chat/question.spec.tsx index 2f8714ef77..99c25f5659 100644 --- a/web/app/components/base/chat/chat/question.spec.tsx +++ b/web/app/components/base/chat/chat/question.spec.tsx @@ -1,7 +1,7 @@ import type { Theme } from '../embedded-chatbot/theme/theme-context' import type { ChatConfig, ChatItem, OnRegenerate } from '../types' import type { FileEntity } from '@/app/components/base/file-uploader/types' -import { act, render, screen, waitFor } from '@testing-library/react' +import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import copy from 'copy-to-clipboard' import * as React from 'react' @@ -180,7 +180,7 @@ describe('Question 
component', () => { await user.clear(textbox) await user.type(textbox, 'Edited question') - const resendBtn = screen.getByRole('button', { name: /chat.resend/i }) + const resendBtn = screen.getByRole('button', { name: /operation.save/i }) await user.click(resendBtn) await waitFor(() => { @@ -209,6 +209,91 @@ describe('Question component', () => { }) }) + it('should confirm editing when Enter is pressed', async () => { + const user = userEvent.setup() + const onRegenerate = vi.fn() as unknown as OnRegenerate + + renderWithProvider(makeItem(), onRegenerate) + + await user.click(screen.getByTestId('edit-btn')) + const textbox = await screen.findByRole('textbox') + + await user.clear(textbox) + await user.type(textbox, 'Edited with Enter') + + fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' }) + + await waitFor(() => { + expect(onRegenerate).toHaveBeenCalledWith(makeItem(), { message: 'Edited with Enter', files: [] }) + }) + }) + + it('should insert a new line when Shift+Enter is pressed', async () => { + const user = userEvent.setup() + const onRegenerate = vi.fn() as unknown as OnRegenerate + + renderWithProvider(makeItem(), onRegenerate) + + await user.click(screen.getByTestId('edit-btn')) + const textbox = await screen.findByRole('textbox') + + await user.clear(textbox) + await user.type(textbox, 'Line 1') + await user.type(textbox, '{Shift>}{Enter}{/Shift}') + + expect(textbox).toHaveValue('Line 1\n') + expect(onRegenerate).not.toHaveBeenCalled() + }) + + it('should not confirm editing when Enter is pressed during IME composition', () => { + const onRegenerate = vi.fn() as unknown as OnRegenerate + + renderWithProvider(makeItem(), onRegenerate) + + fireEvent.click(screen.getByTestId('edit-btn')) + const textbox = screen.getByRole('textbox') + + fireEvent.compositionStart(textbox) + fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' }) + + expect(onRegenerate).not.toHaveBeenCalled() + expect(textbox).toHaveValue('This is the question content') + }) + 
+ it('should keep text unchanged and suppress Enter if a new composition starts before previous composition-end timer finishes', async () => { + vi.useFakeTimers() + + try { + const onRegenerate = vi.fn() as unknown as OnRegenerate + renderWithProvider(makeItem(), onRegenerate) + + fireEvent.click(screen.getByTestId('edit-btn')) + const textbox = screen.getByRole('textbox') + fireEvent.change(textbox, { target: { value: 'IME guard text' } }) + + fireEvent.compositionStart(textbox) + fireEvent.compositionEnd(textbox) + fireEvent.compositionStart(textbox) + + vi.advanceTimersByTime(50) + + const blockedEnterEvent = new KeyboardEvent('keydown', { key: 'Enter', code: 'Enter', bubbles: true, cancelable: true }) + textbox.dispatchEvent(blockedEnterEvent) + expect(onRegenerate).not.toHaveBeenCalled() + expect(blockedEnterEvent.defaultPrevented).toBe(true) + expect(textbox).toHaveValue('IME guard text') + + fireEvent.compositionEnd(textbox) + vi.advanceTimersByTime(50) + + fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' }) + expect(onRegenerate).toHaveBeenCalledWith(makeItem(), { message: 'IME guard text', files: [] }) + } + finally { + vi.useRealTimers() + } + }) + it('should switch siblings when prev/next buttons are clicked', async () => { const user = userEvent.setup() const switchSibling = vi.fn() diff --git a/web/app/components/base/chat/chat/question.tsx b/web/app/components/base/chat/chat/question.tsx index 4c8c7f262d..6eceadf6ea 100644 --- a/web/app/components/base/chat/chat/question.tsx +++ b/web/app/components/base/chat/chat/question.tsx @@ -56,6 +56,8 @@ const Question: FC = ({ const [editedContent, setEditedContent] = useState(content) const [contentWidth, setContentWidth] = useState(0) const contentRef = useRef(null) + const isComposingRef = useRef(false) + const compositionEndTimerRef = useRef | null>(null) const handleEdit = useCallback(() => { setIsEditing(true) @@ -63,15 +65,62 @@ const Question: FC = ({ }, [content]) const handleResend = 
useCallback(() => { + if (compositionEndTimerRef.current) { + clearTimeout(compositionEndTimerRef.current) + compositionEndTimerRef.current = null + } + isComposingRef.current = false setIsEditing(false) onRegenerate?.(item, { message: editedContent, files: message_files }) }, [editedContent, message_files, item, onRegenerate]) const handleCancelEditing = useCallback(() => { + if (compositionEndTimerRef.current) { + clearTimeout(compositionEndTimerRef.current) + compositionEndTimerRef.current = null + } + isComposingRef.current = false setIsEditing(false) setEditedContent(content) }, [content]) + const handleEditInputKeyDown = useCallback((e: React.KeyboardEvent) => { + if (e.key !== 'Enter' || e.shiftKey) + return + + if (e.nativeEvent.isComposing) + return + + if (isComposingRef.current) { + e.preventDefault() + return + } + + e.preventDefault() + handleResend() + }, [handleResend]) + + const clearCompositionEndTimer = useCallback(() => { + if (!compositionEndTimerRef.current) + return + + clearTimeout(compositionEndTimerRef.current) + compositionEndTimerRef.current = null + }, []) + + const handleCompositionStart = useCallback(() => { + clearCompositionEndTimer() + isComposingRef.current = true + }, [clearCompositionEndTimer]) + + const handleCompositionEnd = useCallback(() => { + clearCompositionEndTimer() + compositionEndTimerRef.current = setTimeout(() => { + isComposingRef.current = false + compositionEndTimerRef.current = null + }, 50) + }, [clearCompositionEndTimer]) + const handleSwitchSibling = useCallback((direction: 'prev' | 'next') => { if (direction === 'prev') { if (item.prevSibling) @@ -100,6 +149,12 @@ const Question: FC = ({ } }, []) + useEffect(() => { + return () => { + clearCompositionEndTimer() + } + }, [clearCompositionEndTimer]) + return (
@@ -128,13 +183,17 @@ const Question: FC = ({
{ !!message_files?.length && ( = ({ {!isEditing ? : ( -
-
+
+