diff --git a/.agent/skills b/.agent/skills deleted file mode 120000 index 454b8427cd..0000000000 --- a/.agent/skills +++ /dev/null @@ -1 +0,0 @@ -../.claude/skills \ No newline at end of file diff --git a/.agent/skills/component-refactoring b/.agent/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.agent/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.agent/skills/frontend-code-review b/.agent/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.agent/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.agent/skills/frontend-testing b/.agent/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.agent/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.agent/skills/orpc-contract-first b/.agent/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.agent/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.agent/skills/skill-creator b/.agent/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.agent/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git a/.agent/skills/vercel-react-best-practices b/.agent/skills/vercel-react-best-practices new file mode 120000 index 0000000000..e567923b32 --- /dev/null +++ b/.agent/skills/vercel-react-best-practices @@ -0,0 +1 @@ +../../.agents/skills/vercel-react-best-practices \ No newline at end of file diff --git a/.agent/skills/web-design-guidelines b/.agent/skills/web-design-guidelines new file mode 120000 index 0000000000..886b26ded7 --- /dev/null +++ b/.agent/skills/web-design-guidelines @@ -0,0 +1 @@ +../../.agents/skills/web-design-guidelines \ No newline at end of file diff --git a/.claude/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md similarity index 100% rename from .claude/skills/component-refactoring/SKILL.md rename to .agents/skills/component-refactoring/SKILL.md diff --git a/.claude/skills/component-refactoring/references/complexity-patterns.md b/.agents/skills/component-refactoring/references/complexity-patterns.md similarity index 100% rename from .claude/skills/component-refactoring/references/complexity-patterns.md rename to .agents/skills/component-refactoring/references/complexity-patterns.md diff --git a/.claude/skills/component-refactoring/references/component-splitting.md b/.agents/skills/component-refactoring/references/component-splitting.md similarity index 100% rename from .claude/skills/component-refactoring/references/component-splitting.md rename to .agents/skills/component-refactoring/references/component-splitting.md diff --git a/.claude/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md similarity index 100% rename from .claude/skills/component-refactoring/references/hook-extraction.md rename to .agents/skills/component-refactoring/references/hook-extraction.md diff --git a/.claude/skills/frontend-code-review/SKILL.md b/.agents/skills/frontend-code-review/SKILL.md similarity index 100% rename from .claude/skills/frontend-code-review/SKILL.md rename to 
.agents/skills/frontend-code-review/SKILL.md diff --git a/.claude/skills/frontend-code-review/references/business-logic.md b/.agents/skills/frontend-code-review/references/business-logic.md similarity index 100% rename from .claude/skills/frontend-code-review/references/business-logic.md rename to .agents/skills/frontend-code-review/references/business-logic.md diff --git a/.claude/skills/frontend-code-review/references/code-quality.md b/.agents/skills/frontend-code-review/references/code-quality.md similarity index 100% rename from .claude/skills/frontend-code-review/references/code-quality.md rename to .agents/skills/frontend-code-review/references/code-quality.md diff --git a/.claude/skills/frontend-code-review/references/performance.md b/.agents/skills/frontend-code-review/references/performance.md similarity index 100% rename from .claude/skills/frontend-code-review/references/performance.md rename to .agents/skills/frontend-code-review/references/performance.md diff --git a/.claude/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md similarity index 100% rename from .claude/skills/frontend-testing/SKILL.md rename to .agents/skills/frontend-testing/SKILL.md diff --git a/.claude/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx similarity index 100% rename from .claude/skills/frontend-testing/assets/component-test.template.tsx rename to .agents/skills/frontend-testing/assets/component-test.template.tsx diff --git a/.claude/skills/frontend-testing/assets/hook-test.template.ts b/.agents/skills/frontend-testing/assets/hook-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/assets/hook-test.template.ts rename to .agents/skills/frontend-testing/assets/hook-test.template.ts diff --git a/.claude/skills/frontend-testing/assets/utility-test.template.ts b/.agents/skills/frontend-testing/assets/utility-test.template.ts similarity index 100% rename from .claude/skills/frontend-testing/assets/utility-test.template.ts rename to .agents/skills/frontend-testing/assets/utility-test.template.ts diff --git a/.claude/skills/frontend-testing/references/async-testing.md b/.agents/skills/frontend-testing/references/async-testing.md similarity index 100% rename from .claude/skills/frontend-testing/references/async-testing.md rename to .agents/skills/frontend-testing/references/async-testing.md diff --git a/.claude/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md similarity index 100% rename from .claude/skills/frontend-testing/references/checklist.md rename to .agents/skills/frontend-testing/references/checklist.md diff --git a/.claude/skills/frontend-testing/references/common-patterns.md b/.agents/skills/frontend-testing/references/common-patterns.md similarity index 100% rename from .claude/skills/frontend-testing/references/common-patterns.md rename to .agents/skills/frontend-testing/references/common-patterns.md diff --git a/.claude/skills/frontend-testing/references/domain-components.md b/.agents/skills/frontend-testing/references/domain-components.md similarity index 100% rename from .claude/skills/frontend-testing/references/domain-components.md rename to .agents/skills/frontend-testing/references/domain-components.md diff --git a/.claude/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md similarity index 100% rename from 
.claude/skills/frontend-testing/references/mocking.md rename to .agents/skills/frontend-testing/references/mocking.md diff --git a/.claude/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md similarity index 100% rename from .claude/skills/frontend-testing/references/workflow.md rename to .agents/skills/frontend-testing/references/workflow.md diff --git a/.claude/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md similarity index 100% rename from .claude/skills/orpc-contract-first/SKILL.md rename to .agents/skills/orpc-contract-first/SKILL.md diff --git a/.claude/skills/skill-creator/SKILL.md b/.agents/skills/skill-creator/SKILL.md similarity index 100% rename from .claude/skills/skill-creator/SKILL.md rename to .agents/skills/skill-creator/SKILL.md diff --git a/.claude/skills/skill-creator/references/output-patterns.md b/.agents/skills/skill-creator/references/output-patterns.md similarity index 100% rename from .claude/skills/skill-creator/references/output-patterns.md rename to .agents/skills/skill-creator/references/output-patterns.md diff --git a/.claude/skills/skill-creator/references/workflows.md b/.agents/skills/skill-creator/references/workflows.md similarity index 100% rename from .claude/skills/skill-creator/references/workflows.md rename to .agents/skills/skill-creator/references/workflows.md diff --git a/.claude/skills/skill-creator/scripts/init_skill.py b/.agents/skills/skill-creator/scripts/init_skill.py similarity index 100% rename from .claude/skills/skill-creator/scripts/init_skill.py rename to .agents/skills/skill-creator/scripts/init_skill.py diff --git a/.claude/skills/skill-creator/scripts/package_skill.py b/.agents/skills/skill-creator/scripts/package_skill.py similarity index 100% rename from .claude/skills/skill-creator/scripts/package_skill.py rename to .agents/skills/skill-creator/scripts/package_skill.py diff --git a/.claude/skills/skill-creator/scripts/quick_validate.py b/.agents/skills/skill-creator/scripts/quick_validate.py similarity index 100% rename from .claude/skills/skill-creator/scripts/quick_validate.py rename to .agents/skills/skill-creator/scripts/quick_validate.py diff --git a/.claude/skills/vercel-react-best-practices/AGENTS.md b/.agents/skills/vercel-react-best-practices/AGENTS.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/AGENTS.md rename to .agents/skills/vercel-react-best-practices/AGENTS.md diff --git a/.claude/skills/vercel-react-best-practices/SKILL.md b/.agents/skills/vercel-react-best-practices/SKILL.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/SKILL.md rename to .agents/skills/vercel-react-best-practices/SKILL.md diff --git a/.claude/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md b/.agents/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md rename to .agents/skills/vercel-react-best-practices/rules/advanced-event-handler-refs.md diff --git a/.claude/skills/vercel-react-best-practices/rules/advanced-use-latest.md b/.agents/skills/vercel-react-best-practices/rules/advanced-use-latest.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/advanced-use-latest.md rename to .agents/skills/vercel-react-best-practices/rules/advanced-use-latest.md diff --git 
a/.claude/skills/vercel-react-best-practices/rules/async-api-routes.md b/.agents/skills/vercel-react-best-practices/rules/async-api-routes.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/async-api-routes.md rename to .agents/skills/vercel-react-best-practices/rules/async-api-routes.md diff --git a/.claude/skills/vercel-react-best-practices/rules/async-defer-await.md b/.agents/skills/vercel-react-best-practices/rules/async-defer-await.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/async-defer-await.md rename to .agents/skills/vercel-react-best-practices/rules/async-defer-await.md diff --git a/.claude/skills/vercel-react-best-practices/rules/async-dependencies.md b/.agents/skills/vercel-react-best-practices/rules/async-dependencies.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/async-dependencies.md rename to .agents/skills/vercel-react-best-practices/rules/async-dependencies.md diff --git a/.claude/skills/vercel-react-best-practices/rules/async-parallel.md b/.agents/skills/vercel-react-best-practices/rules/async-parallel.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/async-parallel.md rename to .agents/skills/vercel-react-best-practices/rules/async-parallel.md diff --git a/.claude/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md b/.agents/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md rename to .agents/skills/vercel-react-best-practices/rules/async-suspense-boundaries.md diff --git a/.claude/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md b/.agents/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md rename to .agents/skills/vercel-react-best-practices/rules/bundle-barrel-imports.md diff --git a/.claude/skills/vercel-react-best-practices/rules/bundle-conditional.md b/.agents/skills/vercel-react-best-practices/rules/bundle-conditional.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/bundle-conditional.md rename to .agents/skills/vercel-react-best-practices/rules/bundle-conditional.md diff --git a/.claude/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md b/.agents/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md rename to .agents/skills/vercel-react-best-practices/rules/bundle-defer-third-party.md diff --git a/.claude/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md b/.agents/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md rename to .agents/skills/vercel-react-best-practices/rules/bundle-dynamic-imports.md diff --git a/.claude/skills/vercel-react-best-practices/rules/bundle-preload.md b/.agents/skills/vercel-react-best-practices/rules/bundle-preload.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/bundle-preload.md rename to .agents/skills/vercel-react-best-practices/rules/bundle-preload.md diff --git a/.claude/skills/vercel-react-best-practices/rules/client-event-listeners.md 
b/.agents/skills/vercel-react-best-practices/rules/client-event-listeners.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/client-event-listeners.md rename to .agents/skills/vercel-react-best-practices/rules/client-event-listeners.md diff --git a/.claude/skills/vercel-react-best-practices/rules/client-localstorage-schema.md b/.agents/skills/vercel-react-best-practices/rules/client-localstorage-schema.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/client-localstorage-schema.md rename to .agents/skills/vercel-react-best-practices/rules/client-localstorage-schema.md diff --git a/.claude/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md b/.agents/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md rename to .agents/skills/vercel-react-best-practices/rules/client-passive-event-listeners.md diff --git a/.claude/skills/vercel-react-best-practices/rules/client-swr-dedup.md b/.agents/skills/vercel-react-best-practices/rules/client-swr-dedup.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/client-swr-dedup.md rename to .agents/skills/vercel-react-best-practices/rules/client-swr-dedup.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-batch-dom-css.md b/.agents/skills/vercel-react-best-practices/rules/js-batch-dom-css.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-batch-dom-css.md rename to .agents/skills/vercel-react-best-practices/rules/js-batch-dom-css.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-cache-function-results.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-function-results.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-cache-function-results.md rename to .agents/skills/vercel-react-best-practices/rules/js-cache-function-results.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-cache-property-access.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-property-access.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-cache-property-access.md rename to .agents/skills/vercel-react-best-practices/rules/js-cache-property-access.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-cache-storage.md b/.agents/skills/vercel-react-best-practices/rules/js-cache-storage.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-cache-storage.md rename to .agents/skills/vercel-react-best-practices/rules/js-cache-storage.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-combine-iterations.md b/.agents/skills/vercel-react-best-practices/rules/js-combine-iterations.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-combine-iterations.md rename to .agents/skills/vercel-react-best-practices/rules/js-combine-iterations.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-early-exit.md b/.agents/skills/vercel-react-best-practices/rules/js-early-exit.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-early-exit.md rename to .agents/skills/vercel-react-best-practices/rules/js-early-exit.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-hoist-regexp.md 
b/.agents/skills/vercel-react-best-practices/rules/js-hoist-regexp.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-hoist-regexp.md rename to .agents/skills/vercel-react-best-practices/rules/js-hoist-regexp.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-index-maps.md b/.agents/skills/vercel-react-best-practices/rules/js-index-maps.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-index-maps.md rename to .agents/skills/vercel-react-best-practices/rules/js-index-maps.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-length-check-first.md b/.agents/skills/vercel-react-best-practices/rules/js-length-check-first.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-length-check-first.md rename to .agents/skills/vercel-react-best-practices/rules/js-length-check-first.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-min-max-loop.md b/.agents/skills/vercel-react-best-practices/rules/js-min-max-loop.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-min-max-loop.md rename to .agents/skills/vercel-react-best-practices/rules/js-min-max-loop.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-set-map-lookups.md b/.agents/skills/vercel-react-best-practices/rules/js-set-map-lookups.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-set-map-lookups.md rename to .agents/skills/vercel-react-best-practices/rules/js-set-map-lookups.md diff --git a/.claude/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md b/.agents/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md rename to .agents/skills/vercel-react-best-practices/rules/js-tosorted-immutable.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-activity.md b/.agents/skills/vercel-react-best-practices/rules/rendering-activity.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-activity.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-activity.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md b/.agents/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-animate-svg-wrapper.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-conditional-render.md b/.agents/skills/vercel-react-best-practices/rules/rendering-conditional-render.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-conditional-render.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-conditional-render.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-content-visibility.md b/.agents/skills/vercel-react-best-practices/rules/rendering-content-visibility.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-content-visibility.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-content-visibility.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md 
b/.agents/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-hoist-jsx.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md b/.agents/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-hydration-no-flicker.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rendering-svg-precision.md b/.agents/skills/vercel-react-best-practices/rules/rendering-svg-precision.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rendering-svg-precision.md rename to .agents/skills/vercel-react-best-practices/rules/rendering-svg-precision.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-defer-reads.md b/.agents/skills/vercel-react-best-practices/rules/rerender-defer-reads.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-defer-reads.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-defer-reads.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-dependencies.md b/.agents/skills/vercel-react-best-practices/rules/rerender-dependencies.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-dependencies.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-dependencies.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-derived-state.md b/.agents/skills/vercel-react-best-practices/rules/rerender-derived-state.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-derived-state.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-derived-state.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md b/.agents/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-functional-setstate.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md b/.agents/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-lazy-state-init.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-memo.md b/.agents/skills/vercel-react-best-practices/rules/rerender-memo.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-memo.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-memo.md diff --git a/.claude/skills/vercel-react-best-practices/rules/rerender-transitions.md b/.agents/skills/vercel-react-best-practices/rules/rerender-transitions.md similarity index 100% rename from .claude/skills/vercel-react-best-practices/rules/rerender-transitions.md rename to .agents/skills/vercel-react-best-practices/rules/rerender-transitions.md diff --git 
a/.claude/skills/vercel-react-best-practices/rules/server-after-nonblocking.md b/.agents/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
similarity index 100%
rename from .claude/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
rename to .agents/skills/vercel-react-best-practices/rules/server-after-nonblocking.md
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-cache-lru.md b/.agents/skills/vercel-react-best-practices/rules/server-cache-lru.md
similarity index 100%
rename from .claude/skills/vercel-react-best-practices/rules/server-cache-lru.md
rename to .agents/skills/vercel-react-best-practices/rules/server-cache-lru.md
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-cache-react.md b/.agents/skills/vercel-react-best-practices/rules/server-cache-react.md
similarity index 100%
rename from .claude/skills/vercel-react-best-practices/rules/server-cache-react.md
rename to .agents/skills/vercel-react-best-practices/rules/server-cache-react.md
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-parallel-fetching.md b/.agents/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
similarity index 100%
rename from .claude/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
rename to .agents/skills/vercel-react-best-practices/rules/server-parallel-fetching.md
diff --git a/.claude/skills/vercel-react-best-practices/rules/server-serialization.md b/.agents/skills/vercel-react-best-practices/rules/server-serialization.md
similarity index 100%
rename from .claude/skills/vercel-react-best-practices/rules/server-serialization.md
rename to .agents/skills/vercel-react-best-practices/rules/server-serialization.md
diff --git a/.agents/skills/web-design-guidelines/SKILL.md b/.agents/skills/web-design-guidelines/SKILL.md
new file mode 100644
index 0000000000..ceae92ab31
--- /dev/null
+++ b/.agents/skills/web-design-guidelines/SKILL.md
@@ -0,0 +1,39 @@
+---
+name: web-design-guidelines
+description: Review UI code for Web Interface Guidelines compliance. Use when asked to "review my UI", "check accessibility", "audit design", "review UX", or "check my site against best practices".
+metadata:
+  author: vercel
+  version: "1.0.0"
+  argument-hint:
+---
+
+# Web Interface Guidelines
+
+Review files for compliance with Web Interface Guidelines.
+
+## How It Works
+
+1. Fetch the latest guidelines from the source URL below
+2. Read the specified files (or prompt user for files/pattern)
+3. Check against all rules in the fetched guidelines
+4. Output findings in the terse `file:line` format
+
+## Guidelines Source
+
+Fetch fresh guidelines before each review:
+
+```
+https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
+```
+
+Use WebFetch to retrieve the latest rules. The fetched content contains all the rules and output format instructions.
+
+## Usage
+
+When a user provides a file or pattern argument:
+1. Fetch guidelines from the source URL above
+2. Read the specified files
+3. Apply all rules from the fetched guidelines
+4. Output findings using the format specified in the guidelines
+
+If no files specified, ask the user which files to review.
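The SKILL.md added above describes a simple loop: fetch the guidelines, read the requested files, apply the rules, and report findings in terse `file:line` form. A minimal Python sketch of that loop is shown below; it assumes plain `urllib` in place of the agent's WebFetch tool, and `check_line()` is a hypothetical placeholder standing in for the fetched rules.

```python
# Minimal sketch of the web-design-guidelines review flow (illustrative only).
# Assumptions: urllib stands in for the agent's WebFetch tool, and check_line()
# is a placeholder for evaluating the fetched rule text against each line.
import sys
import urllib.request

GUIDELINES_URL = "https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md"


def fetch_guidelines() -> str:
    """Fetch the latest guideline text before each review."""
    with urllib.request.urlopen(GUIDELINES_URL) as resp:
        return resp.read().decode("utf-8")


def check_line(line: str, rules: str) -> str | None:
    """Placeholder rule check: flags inline styles as one example finding."""
    if 'style="' in line:
        return "avoid inline styles; prefer class-based styling"
    return None


def review(paths: list[str]) -> None:
    rules = fetch_guidelines()
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                finding = check_line(line, rules)
                if finding:
                    # Terse file:line output, matching the format the skill specifies.
                    print(f"{path}:{lineno} {finding}")


if __name__ == "__main__":
    review(sys.argv[1:])
```

Run against one or more UI files (for example, a component `.tsx` file) to print one finding per line, mirroring the skill's output format.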
diff --git a/.claude/skills/component-refactoring b/.claude/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.claude/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.claude/skills/frontend-code-review b/.claude/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.claude/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.claude/skills/frontend-testing b/.claude/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.claude/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.claude/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.claude/skills/skill-creator b/.claude/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.claude/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git a/.claude/skills/vercel-react-best-practices b/.claude/skills/vercel-react-best-practices new file mode 120000 index 0000000000..e567923b32 --- /dev/null +++ b/.claude/skills/vercel-react-best-practices @@ -0,0 +1 @@ +../../.agents/skills/vercel-react-best-practices \ No newline at end of file diff --git a/.claude/skills/web-design-guidelines b/.claude/skills/web-design-guidelines new file mode 120000 index 0000000000..886b26ded7 --- /dev/null +++ b/.claude/skills/web-design-guidelines @@ -0,0 +1 @@ +../../.agents/skills/web-design-guidelines \ No newline at end of file diff --git a/.codex/skills b/.codex/skills deleted file mode 120000 index 454b8427cd..0000000000 --- a/.codex/skills +++ /dev/null @@ -1 +0,0 @@ -../.claude/skills \ No newline at end of file diff --git a/.codex/skills/component-refactoring b/.codex/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.codex/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.codex/skills/frontend-code-review b/.codex/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.codex/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.codex/skills/frontend-testing b/.codex/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.codex/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.codex/skills/orpc-contract-first b/.codex/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.codex/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.codex/skills/skill-creator b/.codex/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.codex/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git a/.codex/skills/vercel-react-best-practices 
b/.codex/skills/vercel-react-best-practices new file mode 120000 index 0000000000..e567923b32 --- /dev/null +++ b/.codex/skills/vercel-react-best-practices @@ -0,0 +1 @@ +../../.agents/skills/vercel-react-best-practices \ No newline at end of file diff --git a/.codex/skills/web-design-guidelines b/.codex/skills/web-design-guidelines new file mode 120000 index 0000000000..886b26ded7 --- /dev/null +++ b/.codex/skills/web-design-guidelines @@ -0,0 +1 @@ +../../.agents/skills/web-design-guidelines \ No newline at end of file diff --git a/.cursor/skills/component-refactoring b/.cursor/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.cursor/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.cursor/skills/frontend-code-review b/.cursor/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.cursor/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.cursor/skills/frontend-testing b/.cursor/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.cursor/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.cursor/skills/orpc-contract-first b/.cursor/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.cursor/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.cursor/skills/skill-creator b/.cursor/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.cursor/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git a/.cursor/skills/vercel-react-best-practices b/.cursor/skills/vercel-react-best-practices new file mode 120000 index 0000000000..e567923b32 --- /dev/null +++ b/.cursor/skills/vercel-react-best-practices @@ -0,0 +1 @@ +../../.agents/skills/vercel-react-best-practices \ No newline at end of file diff --git a/.cursor/skills/web-design-guidelines b/.cursor/skills/web-design-guidelines new file mode 120000 index 0000000000..886b26ded7 --- /dev/null +++ b/.cursor/skills/web-design-guidelines @@ -0,0 +1 @@ +../../.agents/skills/web-design-guidelines \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 220f77e5ce..637593b9de 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -8,7 +8,7 @@ pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc -echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc +echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f 
docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc diff --git a/.gemini/skills/component-refactoring b/.gemini/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.gemini/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.gemini/skills/frontend-code-review b/.gemini/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.gemini/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.gemini/skills/frontend-testing b/.gemini/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.gemini/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.gemini/skills/orpc-contract-first b/.gemini/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.gemini/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.gemini/skills/skill-creator b/.gemini/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.gemini/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git a/.gemini/skills/vercel-react-best-practices b/.gemini/skills/vercel-react-best-practices new file mode 120000 index 0000000000..e567923b32 --- /dev/null +++ b/.gemini/skills/vercel-react-best-practices @@ -0,0 +1 @@ +../../.agents/skills/vercel-react-best-practices \ No newline at end of file diff --git a/.gemini/skills/web-design-guidelines b/.gemini/skills/web-design-guidelines new file mode 120000 index 0000000000..886b26ded7 --- /dev/null +++ b/.gemini/skills/web-design-guidelines @@ -0,0 +1 @@ +../../.agents/skills/web-design-guidelines \ No newline at end of file diff --git a/.github/skills/component-refactoring b/.github/skills/component-refactoring new file mode 120000 index 0000000000..53ae67e2f2 --- /dev/null +++ b/.github/skills/component-refactoring @@ -0,0 +1 @@ +../../.agents/skills/component-refactoring \ No newline at end of file diff --git a/.github/skills/frontend-code-review b/.github/skills/frontend-code-review new file mode 120000 index 0000000000..55654ffbd7 --- /dev/null +++ b/.github/skills/frontend-code-review @@ -0,0 +1 @@ +../../.agents/skills/frontend-code-review \ No newline at end of file diff --git a/.github/skills/frontend-testing b/.github/skills/frontend-testing new file mode 120000 index 0000000000..092cec7745 --- /dev/null +++ b/.github/skills/frontend-testing @@ -0,0 +1 @@ +../../.agents/skills/frontend-testing \ No newline at end of file diff --git a/.github/skills/orpc-contract-first b/.github/skills/orpc-contract-first new file mode 120000 index 0000000000..da47b335c7 --- /dev/null +++ b/.github/skills/orpc-contract-first @@ -0,0 +1 @@ +../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.github/skills/skill-creator b/.github/skills/skill-creator new file mode 120000 index 0000000000..b87455490f --- /dev/null +++ b/.github/skills/skill-creator @@ -0,0 +1 @@ +../../.agents/skills/skill-creator \ No newline at end of file diff --git 
a/.github/skills/vercel-react-best-practices b/.github/skills/vercel-react-best-practices
new file mode 120000
index 0000000000..e567923b32
--- /dev/null
+++ b/.github/skills/vercel-react-best-practices
@@ -0,0 +1 @@
+../../.agents/skills/vercel-react-best-practices
\ No newline at end of file
diff --git a/.github/skills/web-design-guidelines b/.github/skills/web-design-guidelines
new file mode 120000
index 0000000000..886b26ded7
--- /dev/null
+++ b/.github/skills/web-design-guidelines
@@ -0,0 +1 @@
+../../.agents/skills/web-design-guidelines
\ No newline at end of file
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index ff006324bb..4a8c61e7d2 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -79,9 +79,32 @@ jobs:
          find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
          find . -name "*.py.bak" -type f -delete

+      - name: Install pnpm
+        uses: pnpm/action-setup@v4
+        with:
+          package_json_file: web/package.json
+          run_install: false
+
+      - name: Setup Node.js
+        uses: actions/setup-node@v6
+        with:
+          node-version: 24
+          cache: pnpm
+          cache-dependency-path: ./web/pnpm-lock.yaml
+
+      - name: Install web dependencies
+        run: |
+          cd web
+          pnpm install --frozen-lockfile
+
+      - name: ESLint autofix
+        run: |
+          cd web
+          pnpm lint:fix || true
+
      # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
      - name: mdformat
        run: |
-          uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
+          uvx --python 3.13 mdformat . --exclude ".agents/skills/**"

      - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 5551030f1e..fdc05d1d65 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -125,7 +125,7 @@ jobs:
      - name: Web type check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
-        run: pnpm run type-check:tsgo
+        run: pnpm run type-check

      - name: Web dead code check
        if: steps.changed-files.outputs.any_changed == 'true'
diff --git a/api/.env.example b/api/.env.example
index 15981c14b8..c3b1474549 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -715,4 +715,5 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
 SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
+SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
diff --git a/api/.importlinter b/api/.importlinter
index 2dec958788..b676e97591 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -27,7 +27,9 @@ ignore_imports =
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
-    core.workflow.nodes.node_factory -> core.workflow.graph
+    core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
+    core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
+    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
@@ -57,6 +59,252 @@ ignore_imports =
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

+[importlinter:contract:workflow-external-imports]
+name = Workflow
External Imports +type = forbidden +source_modules = + core.workflow +forbidden_modules = + configs + controllers + extensions + models + services + tasks + core.agent + core.app + core.base + core.callback_handler + core.datasource + core.db + core.entities + core.errors + core.extension + core.external_data_tool + core.file + core.helper + core.hosting_configuration + core.indexing_runner + core.llm_generator + core.logging + core.mcp + core.memory + core.model_manager + core.moderation + core.ops + core.plugin + core.prompt + core.provider_manager + core.rag + core.repositories + core.schemas + core.tools + core.trigger + core.variables +ignore_imports = + core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory + core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis + core.workflow.graph_engine.layers.observability -> configs + core.workflow.graph_engine.layers.observability -> extensions.otel.runtime + core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager + core.workflow.graph_engine.worker_management.worker_pool -> configs + core.workflow.nodes.agent.agent_node -> core.model_manager + core.workflow.nodes.agent.agent_node -> core.provider_manager + core.workflow.nodes.agent.agent_node -> core.tools.tool_manager + core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> models.model + core.workflow.nodes.datasource.datasource_node -> models.tools + core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service + core.workflow.nodes.document_extractor.node -> configs + core.workflow.nodes.document_extractor.node -> core.file.file_manager + core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.entities -> configs + core.workflow.nodes.http_request.executor -> configs + core.workflow.nodes.http_request.executor -> core.file.file_manager + core.workflow.nodes.http_request.node -> configs + core.workflow.nodes.http_request.node -> core.tools.tool_file_manager + core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> configs + core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.llm_utils -> core.file.models + core.workflow.nodes.llm.llm_utils -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.llm.llm_utils -> models.model + core.workflow.nodes.llm.llm_utils -> models.provider + core.workflow.nodes.llm.llm_utils -> services.credit_pool_service + core.workflow.nodes.llm.node -> core.tools.signature + core.workflow.nodes.template_transform.template_transform_node -> configs + core.workflow.nodes.tool.tool_node -> 
core.callback_handler.workflow_tool_callback_handler + core.workflow.nodes.tool.tool_node -> core.tools.tool_engine + core.workflow.nodes.tool.tool_node -> core.tools.tool_manager + core.workflow.workflow_entry -> configs + core.workflow.workflow_entry -> models.workflow + core.workflow.nodes.agent.agent_node -> core.agent.entities + core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities + core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities + core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model + core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.start.entities -> core.app.app_config.entities + core.workflow.nodes.start.start_node -> core.app.app_config.entities + core.workflow.workflow_entry -> core.app.apps.exc + core.workflow.workflow_entry -> core.app.entities.app_invoke_entities + core.workflow.workflow_entry -> core.app.workflow.node_factory + core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager + core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager + core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager + core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager + core.workflow.node_events.node -> core.file + core.workflow.nodes.agent.agent_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file + core.workflow.nodes.datasource.datasource_node -> core.file.enums + core.workflow.nodes.document_extractor.node -> core.file + core.workflow.nodes.http_request.executor -> core.file.enums + core.workflow.nodes.http_request.node -> core.file + core.workflow.nodes.http_request.node -> core.file.file_manager + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models + core.workflow.nodes.list_operator.node -> core.file + core.workflow.nodes.llm.file_saver -> core.file + core.workflow.nodes.llm.llm_utils -> core.variables.segments + core.workflow.nodes.llm.node -> core.file + 
core.workflow.nodes.llm.node -> core.file.file_manager + core.workflow.nodes.llm.node -> core.file.models + core.workflow.nodes.loop.entities -> core.variables.types + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file + core.workflow.nodes.protocols -> core.file + core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models + core.workflow.nodes.tool.tool_node -> core.file + core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer + core.workflow.nodes.tool.tool_node -> models + core.workflow.nodes.trigger_webhook.node -> core.file + core.workflow.runtime.variable_pool -> core.file + core.workflow.runtime.variable_pool -> core.file.file_manager + core.workflow.system_variable -> core.file.models + core.workflow.utils.condition.processor -> core.file + core.workflow.utils.condition.processor -> core.file.file_manager + core.workflow.workflow_entry -> core.file.models + core.workflow.workflow_type_encoder -> core.file.models + core.workflow.nodes.agent.agent_node -> models.model + core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider + core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider + core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor + core.workflow.nodes.datasource.datasource_node -> core.variables.variables + core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy + core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy + core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy + core.workflow.nodes.llm.node -> core.helper.code_executor + core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors + core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output + core.workflow.nodes.llm.node -> core.model_manager + core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity + core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform + core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities + core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util + core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods + core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> 
core.rag.retrieval.retrieval_methods + core.workflow.nodes.llm.node -> models.dataset + core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer + core.workflow.nodes.llm.file_saver -> core.tools.signature + core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager + core.workflow.nodes.tool.tool_node -> core.tools.errors + core.workflow.conversation_variable_updater -> core.variables + core.workflow.graph_engine.entities.commands -> core.variables.variables + core.workflow.nodes.agent.agent_node -> core.variables.segments + core.workflow.nodes.answer.answer_node -> core.variables + core.workflow.nodes.code.code_node -> core.variables.segments + core.workflow.nodes.code.code_node -> core.variables.types + core.workflow.nodes.code.entities -> core.variables.types + core.workflow.nodes.datasource.datasource_node -> core.variables.segments + core.workflow.nodes.document_extractor.node -> core.variables + core.workflow.nodes.document_extractor.node -> core.variables.segments + core.workflow.nodes.http_request.executor -> core.variables.segments + core.workflow.nodes.http_request.node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables + core.workflow.nodes.iteration.iteration_node -> core.variables.segments + core.workflow.nodes.iteration.iteration_node -> core.variables.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments + core.workflow.nodes.list_operator.node -> core.variables + core.workflow.nodes.list_operator.node -> core.variables.segments + core.workflow.nodes.llm.node -> core.variables + core.workflow.nodes.loop.loop_node -> core.variables + core.workflow.nodes.parameter_extractor.entities -> core.variables.types + core.workflow.nodes.parameter_extractor.exc -> core.variables.types + core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types + core.workflow.nodes.tool.tool_node -> core.variables.segments + core.workflow.nodes.tool.tool_node -> core.variables.variables + core.workflow.nodes.trigger_webhook.node -> core.variables.types + core.workflow.nodes.trigger_webhook.node -> core.variables.variables + core.workflow.nodes.variable_aggregator.entities -> core.variables.types + core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments + core.workflow.nodes.variable_assigner.common.helpers -> core.variables + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts + core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types + core.workflow.nodes.variable_assigner.v1.node -> core.variables + core.workflow.nodes.variable_assigner.v2.helpers -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables + core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts + core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments + core.workflow.runtime.read_only_wrappers -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables + core.workflow.runtime.variable_pool -> core.variables.consts + core.workflow.runtime.variable_pool -> core.variables.segments + core.workflow.runtime.variable_pool -> core.variables.variables + core.workflow.utils.condition.processor -> core.variables + core.workflow.utils.condition.processor -> core.variables.segments + core.workflow.variable_loader -> core.variables + 
core.workflow.variable_loader -> core.variables.consts + core.workflow.workflow_type_encoder -> core.variables + core.workflow.graph_engine.manager -> extensions.ext_redis + core.workflow.nodes.agent.agent_node -> extensions.ext_database + core.workflow.nodes.datasource.datasource_node -> extensions.ext_database + core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database + core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis + core.workflow.nodes.llm.file_saver -> extensions.ext_database + core.workflow.nodes.llm.llm_utils -> extensions.ext_database + core.workflow.nodes.llm.node -> extensions.ext_database + core.workflow.nodes.tool.tool_node -> extensions.ext_database + core.workflow.workflow_entry -> extensions.otel.runtime + core.workflow.nodes.agent.agent_node -> models + core.workflow.nodes.base.node -> models.enums + core.workflow.nodes.llm.llm_utils -> models.provider_ids + core.workflow.nodes.llm.node -> models.model + core.workflow.workflow_entry -> models.enums + core.workflow.nodes.agent.agent_node -> services + core.workflow.nodes.tool.tool_node -> services + [importlinter:contract:rsc] name = RSC type = layers diff --git a/api/agent-notes/controllers/console/datasets/datasets_document.py.md b/api/agent-notes/controllers/console/datasets/datasets_document.py.md deleted file mode 100644 index b100249981..0000000000 --- a/api/agent-notes/controllers/console/datasets/datasets_document.py.md +++ /dev/null @@ -1,52 +0,0 @@ -## Purpose - -`api/controllers/console/datasets/datasets_document.py` contains the console (authenticated) APIs for managing dataset documents (list/create/update/delete, processing controls, estimates, etc.). - -## Storage model (uploaded files) - -- For local file uploads into a knowledge base, the binary is stored via `extensions.ext_storage.storage` under the key: - - `upload_files//.` -- File metadata is stored in the `upload_files` table (`UploadFile` model), keyed by `UploadFile.id`. -- Dataset `Document` records reference the uploaded file via: - - `Document.data_source_info.upload_file_id` - -## Download endpoint - -- `GET /datasets//documents//download` - - - Only supported when `Document.data_source_type == "upload_file"`. - - Performs dataset permission + tenant checks via `DocumentResource.get_document(...)`. - - Delegates `Document -> UploadFile` validation and signed URL generation to `DocumentService.get_document_download_url(...)`. - - Applies `cloud_edition_billing_rate_limit_check("knowledge")` to match other KB operations. - - Response body is **only**: `{ "url": "" }`. - -- `POST /datasets//documents/download-zip` - - - Accepts `{ "document_ids": ["..."] }` (upload-file only). - - Returns `application/zip` as a single attachment download. - - Rationale: browsers often block multiple automatic downloads; a ZIP avoids that limitation. - - Applies `cloud_edition_billing_rate_limit_check("knowledge")`. - - Delegates dataset permission checks, document/upload-file validation, and download-name generation to - `DocumentService.prepare_document_batch_download_zip(...)` before streaming the ZIP. - -## Verification plan - -- Upload a document from a local file into a dataset. -- Call the download endpoint and confirm it returns a signed URL. -- Open the URL and confirm: - - Response headers force download (`Content-Disposition`), and - - Downloaded bytes match the uploaded file. 
-- Select multiple uploaded-file documents and download as ZIP; confirm all selected files exist in the archive. - -## Shared helper - -- `DocumentService.get_document_download_url(document)` resolves the `UploadFile` and signs a download URL. -- `DocumentService.prepare_document_batch_download_zip(...)` performs dataset permission checks, batches - document + upload file lookups, preserves request order, and generates the client-visible ZIP filename. -- Internal helpers now live in `DocumentService` (`_get_upload_file_id_for_upload_file_document(...)`, - `_get_upload_file_for_upload_file_document(...)`, `_get_upload_files_by_document_id_for_zip_download(...)`). -- ZIP packing is handled by `FileService.build_upload_files_zip_tempfile(...)`, which also: - - sanitizes entry names to avoid path traversal, and - - deduplicates names while preserving extensions (e.g., `doc.txt` → `doc (1).txt`). - Streaming the response and deferring cleanup is handled by the route via `send_file(path, ...)` + `ExitStack` + - `response.call_on_close(...)` (the file is deleted when the response is closed). diff --git a/api/agent-notes/services/dataset_service.py.md b/api/agent-notes/services/dataset_service.py.md deleted file mode 100644 index b68ef345f5..0000000000 --- a/api/agent-notes/services/dataset_service.py.md +++ /dev/null @@ -1,18 +0,0 @@ -## Purpose - -`api/services/dataset_service.py` hosts dataset/document service logic used by console and API controllers. - -## Batch document operations - -- Batch document workflows should avoid N+1 database queries by using set-based lookups. -- Tenant checks must be enforced consistently across dataset/document operations. -- `DocumentService.get_documents_by_ids(...)` fetches documents for a dataset using `id.in_(...)`. -- `FileService.get_upload_files_by_ids(...)` performs tenant-scoped batch lookup for `UploadFile` (dedupes ids with `set(...)`). -- `DocumentService.get_document_download_url(...)` and `prepare_document_batch_download_zip(...)` handle - dataset/document permission checks plus `Document -> UploadFile` validation for download endpoints. - -## Verification plan - -- Exercise document list and download endpoints that use the service helpers. -- Confirm batch download uses constant query count for documents + upload files. -- Request a ZIP with a missing document id and confirm a 404 is returned. diff --git a/api/agent-notes/services/file_service.py.md b/api/agent-notes/services/file_service.py.md deleted file mode 100644 index cf394a1c05..0000000000 --- a/api/agent-notes/services/file_service.py.md +++ /dev/null @@ -1,35 +0,0 @@ -## Purpose - -`api/services/file_service.py` owns business logic around `UploadFile` objects: upload validation, storage persistence, -previews/generators, and deletion. - -## Key invariants - -- All storage I/O goes through `extensions.ext_storage.storage`. -- Uploaded file keys follow: `upload_files//.`. -- Upload validation is enforced in `FileService.upload_file(...)` (blocked extensions, size limits, dataset-only types). - -## Batch lookup helpers - -- `FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)` is the canonical tenant-scoped batch loader for - `UploadFile`. - -## Dataset document download helpers - -The dataset document download/ZIP endpoints now delegate “Document → UploadFile” validation and permission checks to -`DocumentService` (`api/services/dataset_service.py`). `FileService` stays focused on generic `UploadFile` operations -(uploading, previews, deletion), plus generic ZIP serving. 
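To make the entry-name handling described above concrete, here is a minimal sketch of the sanitize-and-deduplicate step; the helper name and exact rules are illustrative assumptions, not the actual `FileService` code.

```python
import os
import posixpath


def safe_zip_entry_name(raw_name: str, used: set[str]) -> str:
    """Illustrative sketch: strip directory components and deduplicate names
    while preserving extensions (e.g. doc.txt -> doc (1).txt)."""
    # Drop any directory components so "../etc/passwd" becomes "passwd".
    name = posixpath.basename(raw_name.replace("\\", "/")) or "file"
    stem, ext = os.path.splitext(name)
    candidate = name
    counter = 1
    while candidate in used:
        candidate = f"{stem} ({counter}){ext}"
        counter += 1
    used.add(candidate)
    return candidate
```

A caller assembling the archive would write each `storage.load(...)` payload into the ZIP under the sanitized name, so traversal-style inputs never become directory entries.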
- -### ZIP serving - -- `FileService.build_upload_files_zip_tempfile(...)` builds a ZIP from `UploadFile` objects and yields a seeked - tempfile **path** so callers can stream it (e.g., `send_file(path, ...)`) without hitting "read of closed file" - issues from file-handle lifecycle during streamed responses. -- Flask `send_file(...)` and the `ExitStack`/`call_on_close(...)` cleanup pattern are handled in the route layer. - -## Verification plan - -- Unit: `api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py` - - Verify signed URL generation for upload-file documents and ZIP download behavior for multiple documents. -- Unit: `api/tests/unit_tests/services/test_file_service_zip_and_lookup.py` - - Verify ZIP packing produces a valid, openable archive and preserves file content. diff --git a/api/agent-notes/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py.md b/api/agent-notes/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py.md deleted file mode 100644 index 8f78dacde8..0000000000 --- a/api/agent-notes/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py.md +++ /dev/null @@ -1,28 +0,0 @@ -## Purpose - -Unit tests for the console dataset document download endpoint: - -- `GET /datasets//documents//download` - -## Testing approach - -- Uses `Flask.test_request_context()` and calls the `Resource.get(...)` method directly. -- Monkeypatches console decorators (`login_required`, `setup_required`, rate limit) to no-ops to keep the test focused. -- Mocks: - - `DatasetService.get_dataset` / `check_dataset_permission` - - `DocumentService.get_document` for single-file download tests - - `DocumentService.get_documents_by_ids` + `FileService.get_upload_files_by_ids` for ZIP download tests - - `FileService.get_upload_files_by_ids` for `UploadFile` lookups in single-file tests - - `services.dataset_service.file_helpers.get_signed_file_url` to return a deterministic URL -- Document mocks include `id` fields so batch lookups can map documents by id. - -## Covered cases - -- Success returns `{ "url": "" }` for upload-file documents. -- 404 when document is not `upload_file`. -- 404 when `upload_file_id` is missing. -- 404 when referenced `UploadFile` row does not exist. -- 403 when document tenant does not match current tenant. -- Batch ZIP download returns `application/zip` for upload-file documents. -- Batch ZIP download rejects non-upload-file documents. -- Batch ZIP download uses a random `.zip` attachment name (`download_name`), so tests only assert the suffix. diff --git a/api/agent-notes/tests/unit_tests/services/test_file_service_zip_and_lookup.py.md b/api/agent-notes/tests/unit_tests/services/test_file_service_zip_and_lookup.py.md deleted file mode 100644 index dbcdf26f10..0000000000 --- a/api/agent-notes/tests/unit_tests/services/test_file_service_zip_and_lookup.py.md +++ /dev/null @@ -1,18 +0,0 @@ -## Purpose - -Unit tests for `api/services/file_service.py` helper methods that are not covered by higher-level controller tests. 
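For reference, the `send_file` + `call_on_close` pattern mentioned in the file_service notes above looks roughly like this; a sketch assuming Flask ≥ 2.0 (`download_name` parameter), with the function name and cleanup details as illustrative assumptions rather than the project's actual route code.

```python
import os
from contextlib import ExitStack

from flask import send_file


def stream_zip_response(zip_path: str, download_name: str):
    """Sketch of the route-layer pattern: stream a prebuilt ZIP tempfile and
    delete it only after the response has been fully sent."""
    stack = ExitStack()
    # Defer tempfile deletion until the response object is closed.
    stack.callback(lambda: os.path.exists(zip_path) and os.remove(zip_path))

    response = send_file(
        zip_path,
        mimetype="application/zip",
        as_attachment=True,
        download_name=download_name,
    )
    # werkzeug invokes this once the response body has been fully consumed/closed.
    response.call_on_close(stack.close)
    return response
```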
- -## What’s covered - -- `FileService.build_upload_files_zip_tempfile(...)` - - ZIP entry name sanitization (no directory components / traversal) - - name deduplication while preserving extensions - - writing streamed bytes from `storage.load(...)` into ZIP entries - - yields a tempfile path so callers can open/stream the ZIP without holding a live file handle -- `FileService.get_upload_files_by_ids(...)` - - returns `{}` for empty id lists - - returns an id-keyed mapping for non-empty lists - -## Notes - -- These tests intentionally stub `storage.load` and `db.session.scalars(...).all()` to avoid needing a real DB/storage. diff --git a/api/agent_skills/infra.md b/api/agent_skills/infra.md deleted file mode 100644 index bc36c7bf64..0000000000 --- a/api/agent_skills/infra.md +++ /dev/null @@ -1,96 +0,0 @@ -## Configuration - -- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly. -- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`. -- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing. -- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`. - -## Dependencies - -- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`. -- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group. -- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current. - -## Storage & Files - -- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend. -- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads. -- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly. -- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform. - -## Redis & Shared State - -- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`. -- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`. - -## Models - -- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`). -- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn. -- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories. -- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below. 
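As a hypothetical example of the model-registration pattern described above (the table, columns, and column style are assumptions; real Dify models may use different types and mixins):

```python
import sqlalchemy as sa

from models.base import Base  # shared declarative Base, as noted above


class ExampleNote(Base):
    """Hypothetical table used only to illustrate the registration pattern."""

    __tablename__ = "example_notes"

    id = sa.Column(sa.String(36), primary_key=True)
    tenant_id = sa.Column(sa.String(36), nullable=False, index=True)
    title = sa.Column(sa.String(255), nullable=False)
    created_at = sa.Column(sa.DateTime, nullable=False, server_default=sa.func.now())
```

After defining the class, export it from `models/__init__.py`, wire a repository if services need it, and generate the Alembic migration as described in the Database & Migrations section below.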
- -## Vector Stores - -- Vector client implementations live in `core/rag/datasource/vdb/`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`. -- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`. -- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions. -- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations. - -## Observability & OTEL - -- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads. -- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints. -- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`). -- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`. - -## Ops Integrations - -- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above. -- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules. -- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata. - -## Controllers, Services, Core - -- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`. -- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs). -- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`. - -## Plugins, Tools, Providers - -- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation. -- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`. -- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way. 
-- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application. -- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config). -- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly. - -## Async Workloads - -see `agent_skills/trigger.md` for more detailed documentation. - -- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`. -- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc. -- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs. - -## Database & Migrations - -- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`. -- Generate migrations with `uv run --project api flask db revision --autogenerate -m ""`, then review the diff; never hand-edit the database outside Alembic. -- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history. -- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables. - -## CLI Commands - -- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask `. -- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour. -- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations. -- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR. -- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes). - -## When You Add Features - -- Check for an existing helper or service before writing a new util. -- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`. -- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations). -- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes. 
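A rough sketch of how a job might be routed to one of the tiered Celery queues mentioned above; the task body, queue names, and tier mapping are placeholders, not the actual `services/async_workflow_service.py` implementation.

```python
from celery import shared_task


@shared_task(bind=True, max_retries=3)
def run_workflow_async(self, tenant_id: str, workflow_id: str, inputs: dict):
    """Placeholder task body; the real work happens in the workflow engine."""
    print(f"run workflow {workflow_id} for tenant {tenant_id} with {inputs}")


def enqueue_for_tier(tier: str, tenant_id: str, workflow_id: str, inputs: dict):
    # Map a billing tier to a queue, mirroring the professional/team/sandbox split.
    queue = {
        "professional": "professional_queue",
        "team": "team_queue",
    }.get(tier, "sandbox_queue")
    run_workflow_async.apply_async(args=(tenant_id, workflow_id, inputs), queue=queue)
```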
diff --git a/api/agent_skills/plugin.md b/api/agent_skills/plugin.md deleted file mode 100644 index 954ddd236b..0000000000 --- a/api/agent_skills/plugin.md +++ /dev/null @@ -1 +0,0 @@ -// TBD diff --git a/api/agent_skills/plugin_oauth.md b/api/agent_skills/plugin_oauth.md deleted file mode 100644 index 954ddd236b..0000000000 --- a/api/agent_skills/plugin_oauth.md +++ /dev/null @@ -1 +0,0 @@ -// TBD diff --git a/api/agent_skills/trigger.md b/api/agent_skills/trigger.md deleted file mode 100644 index f4b076332c..0000000000 --- a/api/agent_skills/trigger.md +++ /dev/null @@ -1,53 +0,0 @@ -## Overview - -Triggers are a family of nodes we call `Start` nodes; the `Start` concept corresponds to `RootNode` in the workflow engine `core/workflow/graph_engine`. A `Start` node is the entry point of a workflow: every workflow run begins at a `Start` node. - -## Trigger nodes - -- `UserInput` -- `Trigger Webhook` -- `Trigger Schedule` -- `Trigger Plugin` - -### UserInput - -Before the `Trigger` concept was introduced, this was simply the `Start` node; it has since been renamed to `UserInput` to avoid confusion and is closely tied to the `ServiceAPI` in `controllers/service_api/app`. - -1. The `UserInput` node declares a list of arguments that must be provided by the user; these are ultimately converted into variables in the workflow variable pool. -1. The `ServiceAPI` accepts those arguments and passes them through to the `UserInput` node. -1. For the detailed implementation, see `core/workflow/nodes/start`. - -### Trigger Webhook - -The Webhook node provides a UI panel where users define an HTTP manifest (`core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`). Dify also generates a random webhook id for each `Trigger Webhook` node; the implementation lives in `core/trigger/utils/endpoint.py`. `webhook-debug` is the webhook's debug mode, handled in `controllers/trigger/webhook.py`. - -Finally, requests to the `webhook` endpoint are converted into variables in the workflow variable pool during workflow execution. - -### Trigger Schedule - -The `Trigger Schedule` node lets users define a schedule that triggers the workflow; the detailed manifest is in `core/workflow/nodes/trigger_schedule/entities.py`. A poller and executor handle millions of schedules; see `docker/entrypoint.sh` and `schedule/workflow_schedule_task.py`. - -To achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and `events/event_handlers/sync_workflow_schedule_when_app_published.py` syncs workflow schedule plans when an app is published. - -### Trigger Plugin - -The `Trigger Plugin` node lets users define their own distributed trigger plugin: whenever a request is received, Dify forwards it to the plugin and waits for the parsed variables. - -1. Requests are saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`. -1. Plugins accept those requests and parse variables from them; see `core/plugin/impl/trigger.py` for details. - -Dify also introduces a `subscription` concept: an endpoint address from Dify is bound to a third-party webhook service such as `Github`, `Slack`, `Linear`, `GoogleDrive`, or `Gmail`. Once a subscription is created, Dify continually receives requests from those platforms and handles them one by one.
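A minimal sketch of the trigger-plugin flow just described (persist the raw request, forward it to the plugin, then start the workflow asynchronously); every name here is a placeholder rather than Dify's real service or plugin interface.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ParsedTrigger:
    variables: dict[str, Any]
    should_run: bool


def handle_plugin_endpoint(
    raw_request: bytes,
    subscription_id: str,
    storage: Any,
    plugin_client: Any,
    enqueue: Callable[..., Any],
):
    # 1. Persist the raw request so it can be replayed or debugged later.
    storage.save(f"trigger_requests/{subscription_id}", raw_request)

    # 2. Forward the request to the plugin and wait for the parsed variables.
    parsed: ParsedTrigger = plugin_client.parse(subscription_id, raw_request)
    if not parsed.should_run:
        return None

    # 3. Start the workflow run asynchronously with the parsed variables.
    return enqueue(subscription_id=subscription_id, inputs=parsed.variables)
```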
- -## Worker Pool / Async Task - -All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`. - -The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`. - -## Debug Strategy - -Dify divided users into 2 groups: builders / end users. - -Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`. - -A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type. diff --git a/api/commands.py b/api/commands.py index aa7b731a27..3d68de4cb4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -950,6 +950,346 @@ def clean_workflow_runs( ) +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. 
+ + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. + """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. 
start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. + + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. + """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. 
success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index cf71a33fa8..786094f295 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -965,6 +965,16 @@ class MailConfig(BaseSettings): default=None, ) + ENABLE_TRIAL_APP: bool = Field( + description="Enable trial app", + default=False, + ) + + ENABLE_EXPLORE_BANNER: bool = Field( + description="Enable explore banner", + default=False, + ) + class RagEtlConfig(BaseSettings): """ @@ -1298,6 +1308,10 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings): description="Retention days for sandbox expired workflow_run records and message records", default=30, ) + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field( + description="Lock TTL for sandbox expired records clean task in seconds", + default=90000, + ) class FeatureConfig( diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 4b693cd91f..2d465c8cf4 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -3,13 +3,14 @@ Flask App Context - Flask implementation of AppContext interface. """ import contextvars +import threading from collections.abc import Generator from contextlib import contextmanager from typing import Any, final from flask import Flask, current_app, g -from context import register_context_capturer +from core.workflow.context import register_context_capturer from core.workflow.context.execution_context import ( AppContext, IExecutionContext, @@ -118,6 +119,7 @@ class FlaskExecutionContext: self._context_vars = context_vars self._user = user self._flask_app = flask_app + self._local = threading.local() @property def app_context(self) -> FlaskAppContext: @@ -136,47 +138,39 @@ class FlaskExecutionContext: def __enter__(self) -> "FlaskExecutionContext": """Enter the Flask execution context.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context - self._cm = self._app_context.enter() - self._cm.__enter__() + cm = self._app_context.enter() + self._local.cm = cm + cm.__enter__() # Restore user in new app context - if saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user return self def __exit__(self, *args: Any) -> None: """Exit the Flask execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) @contextmanager def enter(self) -> Generator[None, None, None]: """Enter Flask execution context as context manager.""" - # Restore context variables + # Restore non-Flask context variables to avoid leaking Flask tokens across threads for var, val in self._context_vars.items(): var.set(val) - # Save current user from g if available - saved_user = None - if hasattr(g, "_login_user"): - saved_user = g._login_user - # Enter Flask app context with self._flask_app.app_context(): # Restore user in new app context - if 
saved_user is not None: - g._login_user = saved_user + if self._user is not None: + g._login_user = self._user yield diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ad878fc266..fdc9aabc83 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -107,10 +107,12 @@ from .datasets.rag_pipeline import ( # Import explore controllers from .explore import ( + banner, installed_app, parameter, recommended_app, saved_message, + trial, ) # Import tag controllers @@ -145,6 +147,7 @@ __all__ = [ "apikey", "app", "audio", + "banner", "billing", "bp", "completion", @@ -198,6 +201,7 @@ __all__ = [ "statistic", "tags", "tool_providers", + "trial", "trigger_providers", "version", "website", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a25ca5ef51..e1ee2c24b8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token -from models.model import App, InstalledApp, RecommendedApp +from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp P = ParamSpec("P") R = TypeVar("R") @@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel): language: str = Field(...) category: str = Field(...) position: int = Field(...) + can_trial: bool = Field(default=False) + trial_limit: int = Field(default=0) @field_validator("language") @classmethod @@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel): return supported_language(value) +class InsertExploreBannerPayload(BaseModel): + category: str = Field(...) + title: str = Field(...) + description: str = Field(...) + img_src: str = Field(..., alias="img-src") + language: str = Field(default="en-US") + link: str = Field(...) + sort: int = Field(...) 
+ + @field_validator("language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + model_config = {"populate_by_name": True} + + console_ns.schema_model( InsertExploreAppPayload.__name__, InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) +console_ns.schema_model( + InsertExploreBannerPayload.__name__, + InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + def admin_required(view: Callable[P, R]): @wraps(view) @@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource): ) db.session.add(recommended_app) + if payload.can_trial: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == payload.app_id) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=payload.app_id, + tenant_id=app.tenant_id, + trial_limit=payload.trial_limit, + ) + ) + else: + trial_app.trial_limit = payload.trial_limit app.is_public = True db.session.commit() @@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource): recommended_app.category = payload.category recommended_app.position = payload.position + if payload.can_trial: + trial_app = db.session.execute( + select(TrialApp).where(TrialApp.app_id == payload.app_id) + ).scalar_one_or_none() + if not trial_app: + db.session.add( + TrialApp( + app_id=payload.app_id, + tenant_id=app.tenant_id, + trial_limit=payload.trial_limit, + ) + ) + else: + trial_app.trial_limit = payload.trial_limit app.is_public = True db.session.commit() @@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource): for installed_app in installed_apps: session.delete(installed_app) + trial_app = session.execute( + select(TrialApp).where(TrialApp.app_id == recommended_app.app_id) + ).scalar_one_or_none() + if trial_app: + session.delete(trial_app) + db.session.delete(recommended_app) db.session.commit() return {"result": "success"}, 204 + + +@console_ns.route("/admin/insert-explore-banner") +class InsertExploreBannerApi(Resource): + @console_ns.doc("insert_explore_banner") + @console_ns.doc(description="Insert an explore banner") + @console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__]) + @console_ns.response(201, "Banner inserted successfully") + @only_edition_cloud + @admin_required + def post(self): + payload = InsertExploreBannerPayload.model_validate(console_ns.payload) + + content = { + "category": payload.category, + "title": payload.title, + "description": payload.description, + "img-src": payload.img_src, + } + + banner = ExporleBanner( + content=content, + link=payload.link, + sort=payload.sort, + language=payload.language, + ) + db.session.add(banner) + db.session.commit() + + return {"result": "success"}, 201 + + +@console_ns.route("/admin/delete-explore-banner/") +class DeleteExploreBannerApi(Resource): + @console_ns.doc("delete_explore_banner") + @console_ns.doc(description="Delete an explore banner") + @console_ns.doc(params={"banner_id": "Banner ID to delete"}) + @console_ns.response(204, "Banner deleted successfully") + @only_edition_cloud + @admin_required + def delete(self, banner_id): + banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none() + if not banner: + raise NotFound(f"Banner '{banner_id}' is not found") + + db.session.delete(banner) + db.session.commit() + + return {"result": "success"}, 204 diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index fbd7901646..6b4bd6755a 
100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException): error_code = "rate_limit_error" description = "Rate Limit Error" code = 429 + + +class NeedAddIdsError(BaseHTTPException): + error_code = "need_add_ids" + description = "Need to add ids." + code = 400 diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index fa67fb8154..6736f24a2e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.workflow.enums import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.workflow_app_log_fields import ( + build_workflow_app_log_pagination_model, + build_workflow_archived_log_pagination_model, +) from libs.login import login_required from models import App from models.model import AppMode @@ -61,6 +64,7 @@ console_ns.schema_model( # Register model for flask_restx to avoid dict type issues in Swagger workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) +workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) @console_ns.route("/apps//workflow-app-logs") @@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination + + +@console_ns.route("/apps//workflow-archived-logs") +class WorkflowArchivedLogApi(Resource): + @console_ns.doc("get_workflow_archived_logs") + @console_ns.doc(description="Get workflow archived execution logs") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) + @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.WORKFLOW]) + @marshal_with(workflow_archived_log_pagination_model) + def get(self, app_model: App): + """ + Get workflow archived logs + """ + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + workflow_app_service = WorkflowAppService() + with Session(db.engine) as session: + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( + session=session, + app_model=app_model, + page=args.page, + limit=args.limit, + ) + + return workflow_app_log_pagination diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 8f1871f1e9..fa74f8aea1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,12 +1,15 @@ +from datetime import UTC, datetime, timedelta from typing import Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import 
simple_account_fields from fields.workflow_run_fields import ( @@ -19,14 +22,17 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom +from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME from services.workflow_run_service import WorkflowRunService # Workflow run status choices for filtering WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] +EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model( "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy ) +workflow_run_export_fields = console_ns.model( + "WorkflowRunExport", + { + "status": fields.String(description="Export status: success/failed"), + "presigned_url": fields.String(description="Pre-signed URL for download", required=False), + "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False), + }, +) + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource): return result +@console_ns.route("/apps//workflow-runs//export") +class WorkflowRunExportApi(Resource): + @console_ns.doc("get_workflow_run_export_url") + @console_ns.doc(description="Generate a download URL for an archived workflow run.") + @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) + @console_ns.response(200, "Export URL generated", workflow_run_export_fields) + @setup_required + @login_required + @account_initialization_required + @get_app_model() + def get(self, app_model: App, run_id: str): + tenant_id = str(app_model.tenant_id) + app_id = str(app_model.id) + run_id_str = str(run_id) + + run_created_at = db.session.scalar( + select(WorkflowArchiveLog.run_created_at) + .where( + WorkflowArchiveLog.tenant_id == tenant_id, + WorkflowArchiveLog.app_id == app_id, + WorkflowArchiveLog.workflow_run_id == run_id_str, + ) + .limit(1) + ) + if not run_created_at: + return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404 + + prefix = ( + f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/" + f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + return {"code": "archive_storage_not_configured", "message": str(e)}, 500 + + presigned_url = archive_storage.generate_presigned_url( + archive_key, + expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS, + ) + expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS) + return { + "status": "success", + "presigned_url": presigned_url, + "presigned_url_expires_at": expires_at.isoformat(), + }, 200 + + @console_ns.route("/apps//advanced-chat/workflow-runs/count") class AdvancedChatAppWorkflowRunCountApi(Resource): 
@console_ns.doc("get_advanced_chat_workflow_runs_count") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 9bb2718f89..e687d980fa 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None: return app_model +def _load_app_model_with_trial(app_id: str) -> App | None: + app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first() + return app_model + + def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func: Callable[P1, R1]): @wraps(view_func) @@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li return decorator else: return decorator(view) + + +def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): + @wraps(view_func) + def decorated_view(*args: P.args, **kwargs: P.kwargs): + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") + + app_id = kwargs.get("app_id") + app_id = str(app_id) + + del kwargs["app_id"] + + app_model = _load_app_model_with_trial(app_id) + + if not app_model: + raise AppNotFoundError() + + app_mode = AppMode.value_of(app_model.mode) + + if mode is not None: + if isinstance(mode, list): + modes = mode + else: + modes = [mode] + + if app_mode not in modes: + mode_values = {m.value for m in modes} + raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") + + kwargs["app_model"] = app_model + + return view_func(*args, **kwargs) + + return decorated_view + + if view is None: + return decorator + else: + return decorator(view) diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py new file mode 100644 index 0000000000..da306fbc9d --- /dev/null +++ b/api/controllers/console/explore/banner.py @@ -0,0 +1,43 @@ +from flask import request +from flask_restx import Resource + +from controllers.console import api +from controllers.console.explore.wraps import explore_banner_enabled +from extensions.ext_database import db +from models.model import ExporleBanner + + +class BannerApi(Resource): + """Resource for banner list.""" + + @explore_banner_enabled + def get(self): + """Get banner list.""" + language = request.args.get("language", "en-US") + + # Build base query for enabled banners + base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") + + # Try to get banners in the requested language + banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + + # Fallback to en-US if no banners found and language is not en-US + if not banners and language != "en-US": + banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all() + # Convert banners to serializable format + result = [] + for banner in banners: + banner_data = { + "id": banner.id, + "content": banner.content, # Already parsed as JSON by SQLAlchemy + "link": banner.link, + "sort": banner.sort, + "status": banner.status, + "created_at": banner.created_at.isoformat() if banner.created_at else None, + } + result.append(banner_data) + + return result + + +api.add_resource(BannerApi, "/explore/banners") diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 1e05ff4206..e96fa64f84 
100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException): error_code = "access_denied" description = "App access denied." code = 403 + + +class TrialAppNotAllowed(BaseHTTPException): + """*403* `Trial App Not Allowed` + + Raise if the user has reached the trial app limit. + """ + + error_code = "trial_app_not_allowed" + code = 403 + description = "the app is not allowed to be trial." + + +class TrialAppLimitExceeded(BaseHTTPException): + """*403* `Trial App Limit Exceeded` + + Raise if the user has exceeded the trial app limit. + """ + + error_code = "trial_app_limit_exceeded" + code = 403 + description = "The user has exceeded the trial app limit." diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 2b2f807694..362513ec1c 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -29,6 +29,7 @@ recommended_app_fields = { "category": fields.String, "position": fields.Integer, "is_listed": fields.Boolean, + "can_trial": fields.Boolean, } recommended_app_list_fields = { diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py new file mode 100644 index 0000000000..97d856bebe --- /dev/null +++ b/api/controllers/console/explore/trial.py @@ -0,0 +1,512 @@ +import logging +from typing import Any, cast + +from flask import request +from flask_restx import Resource, marshal, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.common.fields import Parameters as ParametersResponse +from controllers.common.fields import Site as SiteResponse +from controllers.console import api +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + ConversationCompletedError, + NeedAddIdsError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from controllers.console.app.wraps import get_app_model_with_trial +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, + NotWorkflowAppError, +) +from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from core.model_runtime.errors.invoke import InvokeError +from core.workflow.graph_engine.manager import GraphEngineManager +from extensions.ext_database import db +from fields.app_fields import app_detail_fields_with_site +from fields.dataset_fields import dataset_fields +from fields.workflow_fields import workflow_fields +from libs import helper +from libs.helper import uuid_value +from libs.login import current_user +from models import Account +from models.account import TenantStatus +from models.model import AppMode, Site +from models.workflow import Workflow +from 
services.app_generate_service import AppGenerateService +from services.app_service import AppService +from services.audio_service import AudioService +from services.dataset_service import DatasetService +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) +from services.errors.conversation import ConversationNotExistsError +from services.errors.llm import InvokeRateLimitError +from services.errors.message import ( + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService +from services.recommended_app_service import RecommendedAppService + +logger = logging.getLogger(__name__) + + +class TrialAppWorkflowRunApi(TrialAppResource): + def post(self, trial_app): + """ + Run workflow + """ + app_model = trial_app + if not app_model: + raise NotWorkflowAppError() + app_mode = AppMode.value_of(app_model.mode) + if app_mode != AppMode.WORKFLOW: + raise NotWorkflowAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + args = parser.parse_args() + assert current_user is not None + try: + app_id = app_model.id + user_id = current_user.id + response = AppGenerateService.generate( + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True + ) + RecommendedAppService.add_trial_app_record(app_id, user_id) + return helper.compact_generate_response(response) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialAppWorkflowTaskStopApi(TrialAppResource): + def post(self, trial_app, task_id: str): + """ + Stop workflow task + """ + app_model = trial_app + if not app_model: + raise NotWorkflowAppError() + app_mode = AppMode.value_of(app_model.mode) + if app_mode != AppMode.WORKFLOW: + raise NotWorkflowAppError() + assert current_user is not None + + # Stop using both mechanisms for backward compatibility + # Legacy stop flag mechanism (without user check) + AppQueueManager.set_stop_flag_no_user_check(task_id) + + # New graph engine command channel mechanism + GraphEngineManager.send_stop_command(task_id) + + return {"result": "success"} + + +class TrialChatApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + 
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + args = parser.parse_args() + + args["auto_generate_name"] = False + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + + # Get IDs before they might be detached from session + app_id = app_model.id + user_id = current_user.id + + response = AppGenerateService.generate( + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True + ) + RecommendedAppService.add_trial_app_record(app_id, user_id) + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialMessageSuggestedQuestionApi(TrialAppResource): + @trial_feature_enable + def get(self, trial_app, message_id): + app_model = trial_app + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + message_id = str(message_id) + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + questions = MessageService.get_suggested_questions_after_answer( + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE + ) + except MessageNotExistsError: + raise NotFound("Message not found") + except ConversationNotExistsError: + raise NotFound("Conversation not found") + except SuggestedQuestionsAfterAnswerDisabledError: + raise AppSuggestedQuestionsAfterAnswerDisabledError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + return {"data": questions} + + +class TrialChatAudioApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + + file = request.files["file"] + + try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + + # Get IDs before they might be detached from session + app_id = app_model.id + user_id = current_user.id + + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) + RecommendedAppService.add_trial_app_record(app_id, user_id) + return response + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + 
raise AppUnavailableError() + except NoAudioUploadedServiceError: + raise NoAudioUploadedError() + except AudioTooLargeServiceError as e: + raise AudioTooLargeError(str(e)) + except UnsupportedAudioTypeServiceError: + raise UnsupportedAudioTypeError() + except ProviderNotSupportSpeechToTextServiceError: + raise ProviderNotSupportSpeechToTextError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception as e: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialChatTextApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + try: + parser = reqparse.RequestParser() + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") + args = parser.parse_args() + + message_id = args.get("message_id", None) + text = args.get("text", None) + voice = args.get("voice", None) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + + # Get IDs before they might be detached from session + app_id = app_model.id + user_id = current_user.id + + response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + RecommendedAppService.add_trial_app_record(app_id, user_id) + return response + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except NoAudioUploadedServiceError: + raise NoAudioUploadedError() + except AudioTooLargeServiceError as e: + raise AudioTooLargeError(str(e)) + except UnsupportedAudioTypeServiceError: + raise UnsupportedAudioTypeError() + except ProviderNotSupportSpeechToTextServiceError: + raise ProviderNotSupportSpeechToTextError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception as e: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialCompletionApi(TrialAppResource): + @trial_feature_enable + def post(self, trial_app): + app_model = trial_app + if app_model.mode != "completion": + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") + args = parser.parse_args() + + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False + + try: + if not isinstance(current_user, Account): + raise 
ValueError("current_user must be an Account instance") + + # Get IDs before they might be detached from session + app_id = app_model.id + user_id = current_user.id + + response = AppGenerateService.generate( + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming + ) + + RecommendedAppService.add_trial_app_record(app_id, user_id) + return helper.compact_generate_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logger.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logger.exception("internal server error.") + raise InternalServerError() + + +class TrialSitApi(Resource): + """Resource for trial app sites.""" + + @trial_feature_enable + @get_app_model_with_trial + def get(self, app_model): + """Retrieve app site info. + + Returns the site configuration for the application including theme, icons, and text. + """ + site = db.session.query(Site).where(Site.app_id == app_model.id).first() + + if not site: + raise Forbidden() + + assert app_model.tenant + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + + return SiteResponse.model_validate(site).model_dump(mode="json") + + +class TrialAppParameterApi(Resource): + """Resource for app variables.""" + + @trial_feature_enable + @get_app_model_with_trial + def get(self, app_model): + """Retrieve app parameters.""" + + if app_model is None: + raise AppUnavailableError() + + if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow = app_model.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + + parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) + return ParametersResponse.model_validate(parameters).model_dump(mode="json") + + +class AppApi(Resource): + @trial_feature_enable + @get_app_model_with_trial + @marshal_with(app_detail_fields_with_site) + def get(self, app_model): + """Get app detail""" + + app_service = AppService() + app_model = app_service.get_app(app_model) + + return app_model + + +class AppWorkflowApi(Resource): + @trial_feature_enable + @get_app_model_with_trial + @marshal_with(workflow_fields) + def get(self, app_model): + """Get workflow detail""" + if not app_model.workflow_id: + raise AppUnavailableError() + + workflow = ( + db.session.query(Workflow) + .where( + Workflow.id == app_model.workflow_id, + ) + .first() + ) + return workflow + + +class DatasetListApi(Resource): + @trial_feature_enable + @get_app_model_with_trial + def get(self, app_model): + page = request.args.get("page", default=1, 
type=int) + limit = request.args.get("limit", default=20, type=int) + ids = request.args.getlist("ids") + + tenant_id = app_model.tenant_id + if ids: + datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id) + else: + raise NeedAddIdsError() + + data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields)) + + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} + return response + + +api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion") + +api.add_resource( + TrialMessageSuggestedQuestionApi, + "/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions", + endpoint="trial_app_suggested_question", +) + +api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio") +api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text") + +api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion") + +api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site") + +api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters") + +api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app") + +api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run") +api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop") + +api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow") +api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets") diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 2a97d312aa..38f0a04904 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -2,14 +2,15 @@ from collections.abc import Callable from functools import wraps from typing import Concatenate, ParamSpec, TypeVar +from flask import abort from flask_restx import Resource from werkzeug.exceptions import NotFound -from controllers.console.explore.error import AppAccessDeniedError +from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required -from models import InstalledApp +from models import AccountTrialAppRecord, App, InstalledApp, TrialApp from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | return decorator +def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): + def decorator(view: Callable[Concatenate[App, P], R]): + @wraps(view) + def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): + current_user, _ = current_account_with_tenant() + + trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + + if trial_app is None: + raise TrialAppNotAllowed() + app = trial_app.app + + if app is None: + raise TrialAppNotAllowed() + + account_trial_app_record = ( + db.session.query(AccountTrialAppRecord) + .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) + .first() + ) + if account_trial_app_record: + if account_trial_app_record.count >= trial_app.trial_limit: + raise 
TrialAppLimitExceeded() + + return view(app, *args, **kwargs) + + return decorated + + if view: + return decorator(view) + return decorator + + +def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]: + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if not features.enable_trial_app: + abort(403, "Trial app feature is not enabled.") + return view(*args, **kwargs) + + return decorated + + +def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]: + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if not features.enable_explore_banner: + abort(403, "Explore banner feature is not enabled.") + return view(*args, **kwargs) + + return decorated + + class InstalledAppResource(Resource): # must be reversed if there are multiple decorators @@ -80,3 +136,13 @@ class InstalledAppResource(Resource): account_initialization_required, login_required, ] + + +class TrialAppResource(Resource): + # must be reversed if there are multiple decorators + + method_decorators = [ + trial_app_required, + account_initialization_required, + login_required, + ] diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6951c906e9..d3811e2d1b 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,6 +1,7 @@ from flask_restx import Resource, fields +from werkzeug.exceptions import Unauthorized -from libs.login import current_account_with_tenant, login_required +from libs.login import current_account_with_tenant, current_user, login_required from services.feature_service import FeatureService from . import console_ns @@ -39,5 +40,21 @@ class SystemFeatureApi(Resource): ), ) def get(self): - """Get system-wide feature configuration""" - return FeatureService.get_system_features().model_dump() + """Get system-wide feature configuration + + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required for dashboard initialization. + + Authentication would create circular dependency (can't login without dashboard loading). + + Only non-sensitive configuration data should be returned by this endpoint. + """ + # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` + # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` + # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will + # raise `Unauthorized` exception if authentication token is not provided. 
+ try: + is_authenticated = current_user.is_authenticated + except Unauthorized: + is_authenticated = False + return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c800c0e4e1..49ff4f57dc 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -261,17 +261,6 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -280,6 +269,18 @@ class DocumentAddByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique if not indexing_technique: raise ValueError("indexing_technique is required.") @@ -370,17 +371,6 @@ class DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - args = {} - if "data" in request.form: - args = json.loads(request.form["data"]) - if "doc_form" not in args: - args["doc_form"] = "text_model" - if "doc_language" not in args: - args["doc_language"] = "English" - - # get dataset info - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: @@ -389,6 +379,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): if dataset.provider == "external": raise ValueError("External datasets are not supported.") + args = {} + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = dataset.chunk_structure or "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" + + # get dataset info + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 0157521ae9..34d02a1e51 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, ) +from core.app.workflow.node_factory import DifyNodeFactory from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from core.workflow.entities.graph_init_params import 
GraphInitParams from core.workflow.enums import WorkflowType from core.workflow.graph import Graph from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.runtime import GraphRuntimeState, VariablePool diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 7adf3504ac..2ca153f835 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -25,6 +25,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine.layers.base import GraphEngineLayer @@ -53,7 +54,6 @@ from core.workflow.graph_events import ( ) from core.workflow.graph_events.graph import GraphRunAbortedEvent from core.workflow.nodes import NodeType -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 225b758fcb..a7ea9ef446 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -3,8 +3,8 @@ from datetime import UTC, datetime from typing import Any, ClassVar from pydantic import TypeAdapter -from sqlalchemy.orm import Session, sessionmaker +from core.db.session_factory import session_factory from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_events.base import GraphEngineEvent from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent @@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer): cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity, start_time: datetime, trigger_log_id: str, - session_maker: sessionmaker[Session], ): super().__init__() self.trigger_log_id = trigger_log_id self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity - self.session_maker = session_maker def on_graph_start(self): pass @@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer): Update trigger log with success or failure. 
""" if isinstance(event, tuple(self._STATUS_MAP.keys())): - with self.session_maker() as session: + with session_factory.create_session() as session: repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log = repo.get_by_id(self.trigger_log_id) if not trigger_log: diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py new file mode 100644 index 0000000000..172ee5d703 --- /dev/null +++ b/api/core/app/workflow/__init__.py @@ -0,0 +1,3 @@ +from .node_factory import DifyNodeFactory + +__all__ = ["DifyNodeFactory"] diff --git a/api/core/workflow/nodes/node_factory.py b/api/core/app/workflow/node_factory.py similarity index 98% rename from api/core/workflow/nodes/node_factory.py rename to api/core/app/workflow/node_factory.py index 5c04e5110f..e0a0059a38 100644 --- a/api/core/workflow/nodes/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -15,6 +15,7 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from core.workflow.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, @@ -23,8 +24,6 @@ from core.workflow.nodes.template_transform.template_renderer import ( from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from libs.typing import is_str, is_str_dict -from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING - if TYPE_CHECKING: from core.workflow.entities import GraphInitParams from core.workflow.runtime import GraphRuntimeState diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index f45f15a6da..84f5bf5512 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -35,7 +35,6 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog -from repositories.factory import DifyAPIRepositoryFactory from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: @@ -473,6 +472,9 @@ class TraceTask: if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: + # Lazy import to avoid circular import during module initialization + from repositories.factory import DifyAPIRepositoryFactory + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 13fd579e20..3f57a346cd 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,5 +1,6 @@ import contextlib import json +import logging from collections.abc import Generator, Iterable from copy import deepcopy from datetime import UTC, datetime @@ -36,6 +37,8 @@ from extensions.ext_database import db from models.enums import CreatorUserRole from models.model import Message, MessageFile +logger = logging.getLogger(__name__) + class ToolEngine: """ @@ -123,25 +126,31 @@ class ToolEngine: # transform tool invoke message to get LLM friendly message return plain_text, 
message_files, meta except ToolProviderCredentialValidationError as e: + logger.error(e, exc_info=True) error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.entity.identity.name}" + logger.error(e, exc_info=True) agent_tool_callback.on_tool_error(e) except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) except ToolInvokeError as e: error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) except ToolEngineInvokeError as e: meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) return error_response, [], meta except Exception as e: error_response = f"unknown error: {e}" agent_tool_callback.on_tool_error(e) + logger.error(e, exc_info=True) return error_response, [], ToolInvokeMeta.error_instance(error_response) diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 283744b43b..9c1ceff145 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from factories.file_factory import build_from_mapping -from libs.login import current_user from models import Account, Tenant from models.model import App, EndUser from models.workflow import Workflow @@ -28,21 +27,6 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) -def _try_resolve_user_from_request() -> Account | EndUser | None: - """ - Try to resolve user from Flask request context. - - Returns None if not in a request context or if user is not available. - """ - # Note: `current_user` is a LocalProxy. Never compare it with None directly. - # Use _get_current_object() to dereference the proxy - user = getattr(current_user, "_get_current_object", lambda: current_user)() - # Check if we got a valid user object - if user is not None and hasattr(user, "id"): - return user - return None - - class WorkflowTool(Tool): """ Workflow tool. @@ -223,12 +207,6 @@ class WorkflowTool(Tool): Returns: Account | EndUser | None: The resolved user object, or None if resolution fails. """ - # Try to resolve user from request context first - user = _try_resolve_user_from_request() - if user is not None: - return user - - # Fall back to database resolution return self._resolve_user_from_database(user_id=user_id) def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None: diff --git a/api/core/workflow/context/__init__.py b/api/core/workflow/context/__init__.py index 31e1f2c8d9..1237d6a017 100644 --- a/api/core/workflow/context/__init__.py +++ b/api/core/workflow/context/__init__.py @@ -7,16 +7,28 @@ execution in multi-threaded environments. 
from core.workflow.context.execution_context import ( AppContext, + ContextProviderNotFoundError, ExecutionContext, IExecutionContext, NullAppContext, capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) +from core.workflow.context.models import SandboxContext __all__ = [ "AppContext", + "ContextProviderNotFoundError", "ExecutionContext", "IExecutionContext", "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", ] diff --git a/api/core/workflow/context/execution_context.py b/api/core/workflow/context/execution_context.py index 5a4203be93..e3007530f0 100644 --- a/api/core/workflow/context/execution_context.py +++ b/api/core/workflow/context/execution_context.py @@ -3,10 +3,13 @@ Execution Context - Abstracted context management for workflow execution. """ import contextvars +import threading from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager -from typing import Any, Protocol, final, runtime_checkable +from typing import Any, Protocol, TypeVar, final, runtime_checkable + +from pydantic import BaseModel class AppContext(ABC): @@ -86,6 +89,7 @@ class ExecutionContext: self._app_context = app_context self._context_vars = context_vars self._user = user + self._local = threading.local() @property def app_context(self) -> AppContext | None: @@ -123,14 +127,16 @@ class ExecutionContext: def __enter__(self) -> "ExecutionContext": """Enter the execution context.""" - self._cm = self.enter() - self._cm.__enter__() + cm = self.enter() + self._local.cm = cm + cm.__enter__() return self def __exit__(self, *args: Any) -> None: """Exit the execution context.""" - if hasattr(self, "_cm"): - self._cm.__exit__(*args) + cm = getattr(self._local, "cm", None) + if cm is not None: + cm.__exit__(*args) class NullAppContext(AppContext): @@ -204,13 +210,75 @@ class ExecutionContextBuilder: ) +_capturer: Callable[[], IExecutionContext] | None = None + +# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. +# Key mapping: +# (name, tenant_id) -> provider +# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") +# - tenant_id: tenant identifier string +# Value: +# provider: Callable[[], BaseModel] returning the typed context value +# Type-safety note: +# - This registry cannot enforce that all providers for a given name return the same BaseModel type. +# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), +# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and +# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. +_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} + +T = TypeVar("T", bound=BaseModel) + + +class ContextProviderNotFoundError(KeyError): + """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + + pass + + +def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: + """Register a single enterable execution context capturer (e.g., Flask).""" + global _capturer + _capturer = capturer + + +def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: + """Register a tenant-specific provider for a named context. 
+ + Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. + Consider adding a typed wrapper for this registration in your feature module. + """ + _tenant_context_providers[(name, tenant_id)] = provider + + +def read_context(name: str, *, tenant_id: str) -> BaseModel: + """ + Read a context value for a specific tenant. + + Raises KeyError if the provider for (name, tenant_id) is not registered. + """ + prov = _tenant_context_providers.get((name, tenant_id)) + if prov is None: + raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") + return prov() + + def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - Returns: - IExecutionContext with captured context + If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal + context with NullAppContext + copy of current contextvars. """ - from context import capture_current_context + if _capturer is None: + return ExecutionContext( + app_context=NullAppContext(), + context_vars=contextvars.copy_context(), + ) + return _capturer() - return capture_current_context() + +def reset_context_provider() -> None: + """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + global _capturer + _capturer = None + _tenant_context_providers.clear() diff --git a/api/core/workflow/context/models.py b/api/core/workflow/context/models.py new file mode 100644 index 0000000000..af5a4b2614 --- /dev/null +++ b/api/core/workflow/context/models.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pydantic import AnyHttpUrl, BaseModel + + +class SandboxContext(BaseModel): + """Typed context for sandbox integration. All fields optional by design.""" + + sandbox_url: AnyHttpUrl | None = None + sandbox_token: str | None = None # optional, if later needed for auth + + +__all__ = ["SandboxContext"] diff --git a/api/core/workflow/graph_engine/worker.py b/api/core/workflow/graph_engine/worker.py index 95db5c5c92..6c69ea5df0 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/core/workflow/graph_engine/worker.py @@ -11,7 +11,6 @@ import time from collections.abc import Sequence from datetime import datetime from typing import TYPE_CHECKING, final -from uuid import uuid4 from typing_extensions import override @@ -113,7 +112,7 @@ class Worker(threading.Thread): self._ready_queue.task_done() except Exception as e: error_event = NodeRunFailedEvent( - id=str(uuid4()), + id=node.execution_id, node_id=node.id, node_type=node.node_type, in_iteration_id=None, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 234651ce96..bf3c045fd6 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -235,7 +235,18 @@ class AgentNode(Node[AgentNodeData]): 0, ): value_param = param.get("value", {}) - params[key] = value_param.get("value", "") if value_param is not None else None + if value_param and value_param.get("type", "") == "variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None else: params[key] = None 
parameters = params diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 55c8db40ea..63e0260341 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -469,12 +469,8 @@ class Node(Generic[NodeDataT]): import core.workflow.nodes as _nodes_pkg for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): - # Avoid importing modules that depend on the registry to prevent circular imports - # e.g. node_factory imports node_mapping which builds the mapping here. - if _modname in { - "core.workflow.nodes.node_factory", - "core.workflow.nodes.node_mapping", - }: + # Avoid importing modules that depend on the registry to prevent circular imports. + if _modname == "core.workflow.nodes.node_mapping": continue importlib.import_module(_modname) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 569a4196fb..ced996e7e0 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -588,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 1f9fc8a115..07d05966cc 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): # Import dependencies + from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState # Create GraphInitParams from node attributes diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index ee37314721..c7bcc66c8b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -7,6 +7,7 @@ from typing import Any from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.file.models import File from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams @@ -19,7 +20,6 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from core.workflow.nodes import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from 
core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 51e2c6cdd5..46885761a1 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + archive_workflow_runs, clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, @@ -11,6 +12,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, convert_to_agent_apps, create_tenant, + delete_archived_workflow_runs, extract_plugins, extract_unique_plugins, file_usage, @@ -24,6 +26,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + restore_workflow_runs, setup_datasource_oauth_client, setup_system_tool_oauth_client, setup_system_trigger_oauth_client, @@ -58,6 +61,9 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + archive_workflow_runs, + delete_archived_workflow_runs, + restore_workflow_runs, clean_workflow_runs, clean_expired_messages, ] diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 0ebc03a98c..ae70356322 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -2,7 +2,12 @@ from flask_restx import Namespace, fields from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields from fields.member_fields import build_simple_account_model, simple_account_fields -from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields +from fields.workflow_run_fields import ( + build_workflow_run_for_archived_log_model, + build_workflow_run_for_log_model, + workflow_run_for_archived_log_fields, + workflow_run_for_log_fields, +) from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowAppLogPartial", copied_fields) +workflow_archived_log_partial_fields = { + "id": fields.String, + "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True), + "trigger_metadata": fields.Raw, + "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), + "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), + "created_at": TimestampField, +} + + +def build_workflow_archived_log_partial_model(api_or_ns: Namespace): + """Build the workflow archived log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_archived_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) + + 
workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Namespace): copied_fields = workflow_app_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) return api_or_ns.model("WorkflowAppLogPagination", copied_fields) + + +workflow_archived_log_pagination_fields = { + "page": fields.Integer, + "limit": fields.Integer, + "total": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)), +} + + +def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): + """Build the workflow archived log pagination model for the API or Namespace.""" + workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns) + + copied_fields = workflow_archived_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) + return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 476025064f..35bb442c59 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -23,6 +23,19 @@ def build_workflow_run_for_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) +workflow_run_for_archived_log_fields = { + "id": fields.String, + "status": fields.String, + "triggered_from": fields.String, + "elapsed_time": fields.Float, + "total_tokens": fields.Integer, +} + + +def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): + return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/archive_storage.py b/api/libs/archive_storage.py index f84d226447..66b57ac661 100644 --- a/api/libs/archive_storage.py +++ b/api/libs/archive_storage.py @@ -7,7 +7,6 @@ to S3-compatible object storage. import base64 import datetime -import gzip import hashlib import logging from collections.abc import Generator @@ -39,7 +38,7 @@ class ArchiveStorage: """ S3-compatible storage client for archiving or exporting. - This client provides methods for storing and retrieving archived data in JSONL+gzip format. + This client provides methods for storing and retrieving archived data in JSONL format. 
""" def __init__(self, bucket: str): @@ -69,7 +68,10 @@ class ArchiveStorage: aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY, aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY, region_name=dify_config.ARCHIVE_STORAGE_REGION, - config=Config(s3={"addressing_style": "path"}), + config=Config( + s3={"addressing_style": "path"}, + max_pool_connections=64, + ), ) # Verify bucket accessibility @@ -100,12 +102,18 @@ class ArchiveStorage: """ checksum = hashlib.md5(data).hexdigest() try: - self.client.put_object( + response = self.client.put_object( Bucket=self.bucket, Key=key, Body=data, ContentMD5=self._content_md5(data), ) + etag = response.get("ETag") + if not etag: + raise ArchiveStorageError(f"Missing ETag for '{key}'") + normalized_etag = etag.strip('"') + if normalized_etag != checksum: + raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}") logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum) return checksum except ClientError as e: @@ -240,19 +248,18 @@ class ArchiveStorage: return base64.b64encode(hashlib.md5(data).digest()).decode() @staticmethod - def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes: + def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes: """ - Serialize records to gzipped JSONL format. + Serialize records to JSONL format. Args: records: List of dictionaries to serialize Returns: - Gzipped JSONL bytes + JSONL bytes """ lines = [] for record in records: - # Convert datetime objects to ISO format strings serialized = ArchiveStorage._serialize_record(record) lines.append(orjson.dumps(serialized)) @@ -260,23 +267,22 @@ class ArchiveStorage: if jsonl_content: jsonl_content += b"\n" - return gzip.compress(jsonl_content) + return jsonl_content @staticmethod - def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]: + def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]: """ - Deserialize gzipped JSONL data to records. + Deserialize JSONL data to records. Args: - data: Gzipped JSONL bytes + data: JSONL bytes Returns: List of dictionaries """ - jsonl_content = gzip.decompress(data) records = [] - for line in jsonl_content.splitlines(): + for line in data.splitlines(): if line: records.append(orjson.loads(line)) diff --git a/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py new file mode 100644 index 0000000000..624be1d073 --- /dev/null +++ b/api/migrations/versions/2025_11_06_1603-9e6fa5cbcd80_make_message_annotation_question_not_.py @@ -0,0 +1,60 @@ +"""make message annotation question not nullable + +Revision ID: 9e6fa5cbcd80 +Revises: 03f8dcbc611e +Create Date: 2025-11-06 16:03:54.549378 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '9e6fa5cbcd80' +down_revision = '288345cd01d1' +branch_labels = None +depends_on = None + + +def upgrade(): + bind = op.get_bind() + message_annotations = sa.table( + "message_annotations", + sa.column("id", sa.String), + sa.column("message_id", sa.String), + sa.column("question", sa.Text), + ) + messages = sa.table( + "messages", + sa.column("id", sa.String), + sa.column("query", sa.Text), + ) + update_question_from_message = ( + sa.update(message_annotations) + .where( + sa.and_( + message_annotations.c.question.is_(None), + message_annotations.c.message_id.isnot(None), + ) + ) + .values( + question=sa.select(sa.func.coalesce(messages.c.query, "")) + .where(messages.c.id == message_annotations.c.message_id) + .scalar_subquery() + ) + ) + bind.execute(update_question_from_message) + + fill_remaining_questions = ( + sa.update(message_annotations) + .where(message_annotations.c.question.is_(None)) + .values(question="") + ) + bind.execute(fill_remaining_questions) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False) + + +def downgrade(): + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True) diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py new file mode 100644 index 0000000000..b99ca04e3f --- /dev/null +++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py @@ -0,0 +1,73 @@ +"""add table explore banner and trial + +Revision ID: f9f6d18a37f9 +Revises: 9e6fa5cbcd80 +Create Date: 2026-01-17 11:10:18.079355 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f9f6d18a37f9' +down_revision = '9e6fa5cbcd80' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('account_trial_app_records', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('count', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'), + sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record') + ) + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False) + batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False) + + op.create_table('exporle_banners', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('content', sa.JSON(), nullable=False), + sa.Column('link', sa.String(length=255), nullable=False), + sa.Column('sort', sa.Integer(), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False), + sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey') + ) + op.create_table('trial_apps', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('trial_limit', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id', name='trial_app_pkey'), + sa.UniqueConstraint('app_id', name='unique_trail_app_id') + ) + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('trial_apps', schema=None) as batch_op: + batch_op.drop_index('trial_app_tenant_id_idx') + batch_op.drop_index('trial_app_app_id_idx') + + op.drop_table('trial_apps') + op.drop_table('exporle_banners') + with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op: + batch_op.drop_index('account_trial_app_record_app_id_idx') + batch_op.drop_index('account_trial_app_record_account_id_idx') + + op.drop_table('account_trial_app_records') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py new file mode 100644 index 0000000000..5e7298af54 --- /dev/null +++ b/api/migrations/versions/2026_01_21_1718-9d77545f524e_add_workflow_archive_logs.py @@ -0,0 +1,95 @@ +"""create workflow_archive_logs + +Revision ID: 9d77545f524e +Revises: f9f6d18a37f9 +Create Date: 2026-01-06 17:18:56.292479 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +def _is_pg(conn): + return conn.dialect.name == "postgresql" + +# revision identifiers, used by Alembic. +revision = '9d77545f524e' +down_revision = 'f9f6d18a37f9' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if _is_pg(conn): + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + else: + op.create_table('workflow_archive_logs', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('log_id', models.types.StringUUID(), nullable=True), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('workflow_id', models.types.StringUUID(), 
nullable=False), + sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False), + sa.Column('created_by_role', sa.String(length=255), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('log_created_at', sa.DateTime(), nullable=True), + sa.Column('log_created_from', sa.String(length=255), nullable=True), + sa.Column('run_version', sa.String(length=255), nullable=False), + sa.Column('run_status', sa.String(length=255), nullable=False), + sa.Column('run_triggered_from', sa.String(length=255), nullable=False), + sa.Column('run_error', models.types.LongText(), nullable=True), + sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False), + sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('run_created_at', sa.DateTime(), nullable=False), + sa.Column('run_finished_at', sa.DateTime(), nullable=True), + sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True), + sa.Column('trigger_metadata', models.types.LongText(), nullable=True), + sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey') + ) + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False) + batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False) + + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op: + batch_op.drop_index('workflow_archive_log_workflow_run_id_idx') + batch_op.drop_index('workflow_archive_log_run_created_at_idx') + batch_op.drop_index('workflow_archive_log_app_idx') + + op.drop_table('workflow_archive_logs') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index e23de832dc..74b33130ef 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -35,6 +35,7 @@ from .enums import ( WorkflowTriggerStatus, ) from .model import ( + AccountTrialAppRecord, ApiRequest, ApiToken, App, @@ -47,6 +48,7 @@ from .model import ( DatasetRetrieverResource, DifySetup, EndUser, + ExporleBanner, IconType, InstalledApp, Message, @@ -62,6 +64,7 @@ from .model import ( TagBinding, TenantCreditPool, TraceAppConfig, + TrialApp, UploadFile, ) from .oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -100,6 +103,7 @@ from .workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, + WorkflowArchiveLog, WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom, @@ -114,6 +118,7 @@ __all__ = [ "Account", "AccountIntegrate", "AccountStatus", + "AccountTrialAppRecord", "ApiRequest", "ApiToken", "ApiToolProvider", @@ -150,6 +155,7 @@ __all__ = [ "DocumentSegment", "Embedding", "EndUser", + "ExporleBanner", "ExternalKnowledgeApis", "ExternalKnowledgeBindings", "IconType", @@ -188,6 +194,7 @@ __all__ = [ "ToolLabelBinding", "ToolModelInvoke", "TraceAppConfig", + "TrialApp", "TriggerOAuthSystemClient", "TriggerOAuthTenantClient", "TriggerSubscription", @@ -197,6 +204,7 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", + "WorkflowArchiveLog", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/model.py b/api/models/model.py index d6a0aa3bb3..2eda085c37 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -603,6 +603,64 @@ class InstalledApp(TypeBase): return tenant +class TrialApp(Base): + __tablename__ = "trial_apps" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), + sa.Index("trial_app_app_id_idx", "app_id"), + sa.Index("trial_app_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("app_id", name="unique_trail_app_id"), + ) + + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + app_id = mapped_column(StringUUID, nullable=False) + tenant_id = mapped_column(StringUUID, nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + +class AccountTrialAppRecord(Base): + __tablename__ = "account_trial_app_records" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), + sa.Index("account_trial_app_record_account_id_idx", "account_id"), + sa.Index("account_trial_app_record_app_id_idx", "app_id"), + sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), + ) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + account_id = mapped_column(StringUUID, nullable=False) + app_id = mapped_column(StringUUID, nullable=False) + count = mapped_column(sa.Integer, nullable=False, default=0) + created_at = mapped_column(sa.DateTime, nullable=False, 
server_default=func.current_timestamp()) + + @property + def app(self) -> App | None: + app = db.session.query(App).where(App.id == self.app_id).first() + return app + + @property + def user(self) -> Account | None: + user = db.session.query(Account).where(Account.id == self.account_id).first() + return user + + +class ExporleBanner(Base): + __tablename__ = "exporle_banners" + __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + content = mapped_column(sa.JSON, nullable=False) + link = mapped_column(String(255), nullable=False) + sort = mapped_column(sa.Integer, nullable=False) + status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + + class OAuthProviderApp(TypeBase): """ Globally shared OAuth provider app information. @@ -1423,7 +1481,7 @@ class MessageAnnotation(Base): app_id: Mapped[str] = mapped_column(StringUUID) conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[str | None] = mapped_column(StringUUID) - question: Mapped[str | None] = mapped_column(LongText, nullable=True) + question: Mapped[str] = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) diff --git a/api/models/workflow.py b/api/models/workflow.py index 2ff47e87b9..0efb3a4e44 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1163,6 +1163,69 @@ class WorkflowAppLog(TypeBase): } +class WorkflowArchiveLog(TypeBase): + """ + Workflow archive log. + + Stores essential workflow run snapshot data for archived app logs. + + Field sources: + - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency. + - log_* fields: from WorkflowAppLog when present; null if the run has no app log. + - run_* fields: workflow run snapshot fields from WorkflowRun. + - trigger_metadata: snapshot from WorkflowTriggerLog when present. 
+ """ + + __tablename__ = "workflow_archive_logs" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"), + sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"), + sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"), + sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) + + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + + log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) + log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True) + + run_version: Mapped[str] = mapped_column(String(255), nullable=False) + run_status: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) + run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) + run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) + run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + + trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True) + archived_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + + @property + def workflow_run_summary(self) -> dict[str, Any]: + return { + "id": self.workflow_run_id, + "status": self.run_status, + "triggered_from": self.run_triggered_from, + "elapsed_time": self.run_elapsed_time, + "total_tokens": self.run_total_tokens, + } + + class ConversationVariable(TypeBase): __tablename__ = "workflow_conversation_variables" diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 479eb1ff54..5b3f635301 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): @@ -209,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr The number of executions deleted """ ... 
+ + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + """ + Get offload records by node execution IDs. + + This method retrieves workflow node execution offload records + that belong to the given node execution IDs. + + Args: + session: The database session to use + node_execution_ids: List of node execution IDs to filter by + + Returns: + A sequence of WorkflowNodeExecutionOffload instances + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 1a2b84fdf9..1d3954571f 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -45,7 +45,7 @@ from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, @@ -270,6 +270,58 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + """ + Fetch workflow run IDs that already have archive log records. + """ + ... + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + """ + Fetch archived workflow logs by time range for restore. + """ + ... + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + """ + Fetch a workflow archive log by workflow run ID. + """ + ... + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + """ + Delete archive log by workflow run ID. + + Used after restoring a workflow run to remove the archive log record, + allowing the run to be archived again if needed. + + Args: + session: Database session + run_id: Workflow run ID + + Returns: + Number of records deleted (0 or 1) + """ + ... + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -282,6 +334,61 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowPause]: + """ + Fetch workflow pause records by workflow run ID. + """ + ... + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + """ + Fetch workflow pause reason records by pause IDs. + """ + ... + + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + """ + Fetch workflow app logs by workflow run ID. + """ + ... + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + """ + Create archive log records for a workflow run. + """ + ... 
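Taken together, the archive-log accessors declared above support a restore path along these lines. This is a sketch under assumptions: unarchive_run is a hypothetical helper, and the actual restoration of rows from the archive bundle is elided:

    from sqlalchemy.orm import Session
    from repositories.api_workflow_run_repository import APIWorkflowRunRepository

    def unarchive_run(repo: APIWorkflowRunRepository, session: Session, run_id: str) -> bool:
        archive_log = repo.get_archived_log_by_run_id(run_id)
        if archive_log is None:
            return False  # nothing was archived for this run
        # ... restore workflow_runs and related rows from the archive bundle here ...
        # Removing the archive log allows the run to be archived again if needed.
        return repo.delete_archive_log_by_run_id(session, run_id) > 0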
+ + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Return workflow runs that already have archive logs, for cleanup of `workflow_runs`. + """ + ... + def count_runs_with_related( self, runs: Sequence[WorkflowRun], diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 4a7c975d2c..b19cc73bd1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -351,3 +351,27 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) return int(node_executions_count), int(offloads_count) + + @staticmethod + def get_by_run( + session: Session, + run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Fetch node executions for a run using workflow_run_id. + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_offloads_by_execution_ids( + self, + session: Session, + node_execution_ids: Sequence[str], + ) -> Sequence[WorkflowNodeExecutionOffload]: + if not node_execution_ids: + return [] + + stmt = select(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + return list(session.scalars(stmt)) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 9d2d06e99f..d5214be042 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -40,14 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import ( - WorkflowAppLog, - WorkflowPauseReason, - WorkflowRun, -) -from models.workflow import ( - WorkflowPause as WorkflowPauseModel, -) +from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -369,6 +362,53 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return session.scalars(stmt).all() + def get_archived_run_ids( + self, + session: Session, + run_ids: Sequence[str], + ) -> set[str]: + if not run_ids: + return set() + + stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids)) + return set(session.scalars(stmt).all()) + + def get_archived_log_by_run_id( + self, + run_id: str, + ) -> WorkflowArchiveLog | None: + with self._session_maker() as session: + stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1) + return session.scalar(stmt) + + def delete_archive_log_by_run_id( + self, + session: Session, + run_id: str, + ) -> int: + stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id) + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def get_pause_records_by_run_id( + self, + session: Session, + run_id: str, + ) -> 
Sequence[WorkflowPause]: + stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def get_pause_reason_records_by_run_id( + self, + session: Session, + pause_ids: Sequence[str], + ) -> Sequence[WorkflowPauseReason]: + if not pause_ids: + return [] + + stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + return list(session.scalars(stmt)) + def delete_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -396,9 +436,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 - pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) - ).all() + pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) + pause_ids = session.scalars(pause_stmt).all() pause_reasons_deleted = 0 pauses_deleted = 0 @@ -407,7 +446,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) ) pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 - pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids))) pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 @@ -427,6 +466,124 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): "pause_reasons": pause_reasons_deleted, } + def get_app_logs_by_run_id( + self, + session: Session, + run_id: str, + ) -> Sequence[WorkflowAppLog]: + stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id) + return list(session.scalars(stmt)) + + def create_archive_logs( + self, + session: Session, + run: WorkflowRun, + app_logs: Sequence[WorkflowAppLog], + trigger_metadata: str | None, + ) -> int: + if not app_logs: + archive_log = WorkflowArchiveLog( + log_id=None, + log_created_at=None, + log_created_from=None, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + session.add(archive_log) + return 1 + + archive_logs = [ + WorkflowArchiveLog( + log_id=app_log.id, + log_created_at=app_log.created_at, + log_created_from=app_log.created_from, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + workflow_run_id=run.id, + created_by_role=run.created_by_role, + created_by=run.created_by, + run_version=run.version, + run_status=run.status, + run_triggered_from=run.triggered_from, + run_error=run.error, + run_elapsed_time=run.elapsed_time, + run_total_tokens=run.total_tokens, + run_total_steps=run.total_steps, + run_created_at=run.created_at, + run_finished_at=run.finished_at, + 
run_exceptions_count=run.exceptions_count, + trigger_metadata=trigger_metadata, + ) + for app_log in app_logs + ] + session.add_all(archive_logs) + return len(archive_logs) + + def get_archived_runs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowRun]: + """ + Retrieves WorkflowRun records by joining workflow_archive_logs. + + Used to identify runs that are already archived and ready for deletion. + """ + stmt = ( + select(WorkflowRun) + .join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + + def get_archived_logs_by_time_range( + self, + session: Session, + tenant_ids: Sequence[str] | None, + start_date: datetime, + end_date: datetime, + limit: int, + ) -> Sequence[WorkflowArchiveLog]: + # Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted. + stmt = ( + select(WorkflowArchiveLog) + .where( + WorkflowArchiveLog.run_created_at >= start_date, + WorkflowArchiveLog.run_created_at < end_date, + ) + .order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc()) + .limit(limit) + ) + if tenant_ids: + stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids)) + return list(session.scalars(stmt)) + def count_runs_with_related( self, runs: Sequence[WorkflowRun], @@ -459,7 +616,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ) pause_ids = session.scalars( - select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids)) ).all() pauses_count = len(pause_ids) pause_reasons_count = 0 @@ -511,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): ValueError: If workflow_run_id is invalid or workflow run doesn't exist RuntimeError: If workflow is already paused or in invalid state """ - previous_pause_model_query = select(WorkflowPauseModel).where( - WorkflowPauseModel.workflow_run_id == workflow_run_id - ) + previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id) with self._session_maker() as session, session.begin(): # Get the workflow run workflow_run = session.get(WorkflowRun, workflow_run_id) @@ -538,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPauseModel() + pause_model = WorkflowPause() pause_model.id = str(uuidv7()) pause_model.workflow_id = workflow_run.workflow_id pause_model.workflow_run_id = workflow_run.id @@ -710,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ with self._session_maker() as session, session.begin(): # Get the pause model by ID - pause_model = session.get(WorkflowPauseModel, pause_entity.id) + pause_model = session.get(WorkflowPause, pause_entity.id) if pause_model is None: raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}") self._delete_pause_model(session, pause_model) @staticmethod - def _delete_pause_model(session: Session, pause_model: 
WorkflowPauseModel): + def _delete_pause_model(session: Session, pause_model: WorkflowPause): storage.delete(pause_model.state_object_key) # Delete the pause record @@ -751,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): _limit: int = limit or 1000 pruned_record_ids: list[str] = [] cond = or_( - WorkflowPauseModel.created_at < expiration, + WorkflowPause.created_at < expiration, and_( - WorkflowPauseModel.resumed_at.is_not(null()), - WorkflowPauseModel.resumed_at < resumption_expiration, + WorkflowPause.resumed_at.is_not(null()), + WorkflowPause.resumed_at < resumption_expiration, ), ) # First, collect pause records to delete with their state files # Expired pauses (created before expiration time) - stmt = select(WorkflowPauseModel).where(cond).limit(_limit) + stmt = select(WorkflowPause).where(cond).limit(_limit) with self._session_maker(expire_on_commit=False) as session: # Old resumed pauses (resumed more than resumption_duration ago) @@ -770,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Delete state files from storage for pause in pauses_to_delete: with self._session_maker(expire_on_commit=False) as session, session.begin(): - # todo: this issues a separate query for each WorkflowPauseModel record. + # todo: this issues a separate query for each WorkflowPause record. # consider batching this lookup. try: storage.delete(pause.state_object_key) @@ -1022,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity): def __init__( self, *, - pause_model: WorkflowPauseModel, + pause_model: WorkflowPause, reason_models: Sequence[WorkflowPauseReason], human_input_form: Sequence = (), ) -> None: diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index ebd3745d18..f3dc4cd60b 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -46,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): return self.session.scalar(query) + def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]: + """List trigger logs for a workflow run.""" + query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id) + return list(self.session.scalars(query).all()) + def get_failed_for_retry( self, tenant_id: str, max_retry_count: int = 3, limit: int = 100 ) -> Sequence[WorkflowTriggerLog]: diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index e85bba8823..be5f483b95 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -2,9 +2,11 @@ import logging import time import click +from redis.exceptions import LockError import app from configs import dify_config +from extensions.ext_redis import redis_client from services.retention.conversation.messages_clean_policy import create_message_clean_policy from services.retention.conversation.messages_clean_service import MessagesCleanService @@ -31,12 +33,16 @@ def clean_messages(): ) # Create and run the cleanup service - service = MessagesCleanService.from_days( - policy=policy, - days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, - batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, - ) - stats = service.run() + # lock the task to avoid concurrent execution in case of the future data volume growth + with redis_client.lock( + "retention:clean_messages", 
timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False + ): + service = MessagesCleanService.from_days( + policy=policy, + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + stats = service.run() end_at = time.perf_counter() click.echo( @@ -50,6 +56,16 @@ def clean_messages(): fg="green", ) ) + except LockError: + end_at = time.perf_counter() + logger.exception("clean_messages: acquire task lock failed, skip current execution") + click.echo( + click.style( + f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s", + fg="yellow", + ) + ) + raise except Exception as e: end_at = time.perf_counter() logger.exception("clean_messages failed") diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py index 9f5bf8e150..ff45a3ddf2 100644 --- a/api/schedule/clean_workflow_runs_task.py +++ b/api/schedule/clean_workflow_runs_task.py @@ -1,11 +1,16 @@ +import logging from datetime import UTC, datetime import click +from redis.exceptions import LockError import app from configs import dify_config +from extensions.ext_redis import redis_client from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup +logger = logging.getLogger(__name__) + @app.celery.task(queue="retention") def clean_workflow_runs_task() -> None: @@ -25,19 +30,50 @@ def clean_workflow_runs_task() -> None: start_time = datetime.now(UTC) - WorkflowRunCleanup( - days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, - batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, - start_from=None, - end_before=None, - ).run() + try: + # lock the task to avoid concurrent execution in case of the future data volume growth + with redis_client.lock( + "retention:clean_workflow_runs_task", + timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, + blocking=False, + ): + WorkflowRunCleanup( + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + start_from=None, + end_before=None, + ).run() - end_time = datetime.now(UTC) - elapsed = end_time - start_time - click.echo( - click.style( - f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} " - f"end={end_time.isoformat()} duration={elapsed}", - fg="green", + end_time = datetime.now(UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) ) - ) + except LockError: + end_time = datetime.now(UTC) + elapsed = end_time - start_time + logger.exception("clean_workflow_runs_task: acquire task lock failed, skip current execution") + click.echo( + click.style( + f"Scheduled workflow run cleanup skipped (lock already held). " + f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}", + fg="yellow", + ) + ) + raise + except Exception as e: + end_time = datetime.now(UTC) + elapsed = end_time - start_time + logger.exception("clean_workflow_runs_task failed") + click.echo( + click.style( + f"Scheduled workflow run cleanup failed. 
start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed} - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index b73302508a..56e9cc6a00 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -209,8 +209,12 @@ class AppAnnotationService: if not app: raise NotFound("App not found") + question = args.get("question") + if question is None: + raise ValueError("'question' is required") + annotation = MessageAnnotation( - app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id + app_id=app.id, content=args["answer"], question=question, account_id=current_user.id ) db.session.add(annotation) db.session.commit() @@ -219,7 +223,7 @@ class AppAnnotationService: if annotation_setting: add_annotation_to_index_task.delay( annotation.id, - args["question"], + question, current_tenant_id, app_id, annotation_setting.collection_binding_id, @@ -244,8 +248,12 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") + question = args.get("question") + if question is None: + raise ValueError("'question' is required") + annotation.content = args["answer"] - annotation.question = args["question"] + annotation.question = question db.session.commit() # if annotation reply is enabled , add annotation to index diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 9b853b8337..b2fb3784e8 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -170,6 +170,8 @@ class SystemFeatureModel(BaseModel): plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() + enable_trial_app: bool = False + enable_explore_banner: bool = False class FeatureService: @@ -200,7 +202,7 @@ class FeatureService: return knowledge_rate_limit @classmethod - def get_system_features(cls) -> SystemFeatureModel: + def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() cls._fulfill_system_params_from_env(system_features) @@ -210,7 +212,7 @@ class FeatureService: system_features.webapp_auth.enabled = True system_features.enable_change_email = False system_features.plugin_manager.enabled = True - cls._fulfill_params_from_enterprise(system_features) + cls._fulfill_params_from_enterprise(system_features, is_authenticated) if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True @@ -225,6 +227,8 @@ class FeatureService: system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP + system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): @@ -306,7 +310,7 @@ class FeatureService: features.next_credit_reset_date = billing_info["next_credit_reset_date"] @classmethod - def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False): enterprise_info = EnterpriseService.get_info() if "SSOEnforcedForSignin" in enterprise_info: @@ -343,19 
+347,14 @@ class FeatureService: ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "License" in enterprise_info: - license_info = enterprise_info["License"] + if is_authenticated and (license_info := enterprise_info.get("License")): + features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + features.license.expired_at = license_info.get("expiredAt", "") - if "status" in license_info: - features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - - if "expiredAt" in license_info: - features.license.expired_at = license_info["expiredAt"] - - if "workspaces" in license_info: - features.license.workspaces.enabled = license_info["workspaces"]["enabled"] - features.license.workspaces.limit = license_info["workspaces"]["limit"] - features.license.workspaces.size = license_info["workspaces"]["used"] + if workspaces_info := license_info.get("workspaces"): + features.license.workspaces.enabled = workspaces_info.get("enabled", False) + features.license.workspaces.limit = workspaces_info.get("limit", 0) + features.license.workspaces.size = workspaces_info.get("used", 0) if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 544383a106..6b211a5632 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,4 +1,7 @@ from configs import dify_config +from extensions.ext_database import db +from models.model import AccountTrialAppRecord, TrialApp +from services.feature_service import FeatureService from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -20,6 +23,15 @@ class RecommendedAppService: ) ) + if FeatureService.get_system_features().enable_trial_app: + apps = result["recommended_apps"] + for app in apps: + app_id = app["app_id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + app["can_trial"] = True + else: + app["can_trial"] = False return result @classmethod @@ -32,4 +44,30 @@ class RecommendedAppService: mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() result: dict = retrieval_instance.get_recommend_app_detail(app_id) + if FeatureService.get_system_features().enable_trial_app: + app_id = result["id"] + trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first() + if trial_app_model: + result["can_trial"] = True + else: + result["can_trial"] = False return result + + @classmethod + def add_trial_app_record(cls, app_id: str, account_id: str): + """ + Add trial app record. 
+ :param app_id: app id + :return: + """ + account_trial_app_record = ( + db.session.query(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .first() + ) + if account_trial_app_record: + account_trial_app_record.count += 1 + db.session.commit() + else: + db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id)) + db.session.commit() diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py index e69de29bb2..18dd42c91e 100644 --- a/api/services/retention/workflow_run/__init__.py +++ b/api/services/retention/workflow_run/__init__.py @@ -0,0 +1 @@ +"""Workflow run retention services.""" diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py new file mode 100644 index 0000000000..ea5cbb7740 --- /dev/null +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -0,0 +1,531 @@ +""" +Archive Paid Plan Workflow Run Logs Service. + +This service archives workflow run logs for paid plan users older than the configured +retention period (default: 90 days) to S3-compatible storage. + +Archived tables: +- workflow_runs +- workflow_app_logs +- workflow_node_executions +- workflow_node_execution_offload +- workflow_pauses +- workflow_pause_reasons +- workflow_trigger_logs + +""" + +import datetime +import io +import json +import logging +import time +import zipfile +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import click +from sqlalchemy import inspect +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.workflow.enums import WorkflowType +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.workflow import WorkflowAppLog, WorkflowRun +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION + +logger = logging.getLogger(__name__) + + +@dataclass +class TableStats: + """Statistics for a single archived table.""" + + table_name: str + row_count: int + checksum: str + size_bytes: int + + +@dataclass +class ArchiveResult: + """Result of archiving a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + tables: list[TableStats] = field(default_factory=list) + error: str | None = None + elapsed_time: float = 0.0 + + +@dataclass +class ArchiveSummary: + """Summary of the entire archive operation.""" + + total_runs_processed: int = 0 + runs_archived: int = 0 + runs_skipped: int = 0 + runs_failed: int = 0 + total_elapsed_time: float = 0.0 + + +class WorkflowRunArchiver: + """ + Archive workflow run logs for paid plan users. 
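A typical invocation of this archiver might look like the sketch below (illustrative only; the argument values are examples rather than recommended settings, and the import path simply follows this module's location in the patch):

    from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver

    archiver = WorkflowRunArchiver(days=90, batch_size=100, dry_run=True, delete_after_archive=False)
    summary = archiver.run()
    print(summary.runs_archived, summary.runs_skipped, summary.runs_failed)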
+ + Storage Layout: + {tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/ + └── archive.v1.0.zip + ├── manifest.json + ├── workflow_runs.jsonl + ├── workflow_app_logs.jsonl + ├── workflow_node_executions.jsonl + ├── workflow_node_execution_offload.jsonl + ├── workflow_pauses.jsonl + ├── workflow_pause_reasons.jsonl + └── workflow_trigger_logs.jsonl + """ + + ARCHIVED_TYPE = [ + WorkflowType.WORKFLOW, + WorkflowType.RAG_PIPELINE, + ] + ARCHIVED_TABLES = [ + "workflow_runs", + "workflow_app_logs", + "workflow_node_executions", + "workflow_node_execution_offload", + "workflow_pauses", + "workflow_pause_reasons", + "workflow_trigger_logs", + ] + + start_from: datetime.datetime | None + end_before: datetime.datetime + + def __init__( + self, + days: int = 90, + batch_size: int = 100, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workers: int = 1, + tenant_ids: Sequence[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + delete_after_archive: bool = False, + workflow_run_repo: APIWorkflowRunRepository | None = None, + ): + """ + Initialize the archiver. + + Args: + days: Archive runs older than this many days + batch_size: Number of runs to process per batch + start_from: Optional start time (inclusive) for archiving + end_before: Optional end time (exclusive) for archiving + workers: Number of concurrent workflow runs to archive + tenant_ids: Optional tenant IDs for grayscale rollout + limit: Maximum number of runs to archive (None for unlimited) + dry_run: If True, only preview without making changes + delete_after_archive: If True, delete runs and related data after archiving + """ + self.days = days + self.batch_size = batch_size + if start_from or end_before: + if start_from is None or end_before is None: + raise ValueError("start_from and end_before must be provided together") + if start_from >= end_before: + raise ValueError("start_from must be earlier than end_before") + self.start_from = start_from.replace(tzinfo=datetime.UTC) + self.end_before = end_before.replace(tzinfo=datetime.UTC) + else: + self.start_from = None + self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days) + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else [] + self.limit = limit + self.dry_run = dry_run + self.delete_after_archive = delete_after_archive + self.workflow_run_repo = workflow_run_repo + + def run(self) -> ArchiveSummary: + """ + Main archiving loop. 
+ + Returns: + ArchiveSummary with statistics about the operation + """ + summary = ArchiveSummary() + start_time = time.time() + + click.echo( + click.style( + self._build_start_message(), + fg="white", + ) + ) + + # Initialize archive storage (will raise if not configured) + try: + if not self.dry_run: + storage = get_archive_storage() + else: + storage = None + except ArchiveStorageNotConfiguredError as e: + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + return summary + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + def _archive_with_session(run: WorkflowRun) -> ArchiveResult: + with session_maker() as session: + return self._archive_run(session, storage, run) + + last_seen: tuple[datetime.datetime, str] | None = None + archived_count = 0 + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + while True: + # Check limit + if self.limit and archived_count >= self.limit: + click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow")) + break + + # Fetch batch of runs + runs = self._get_runs_batch(last_seen) + + if not runs: + break + + run_ids = [run.id for run in runs] + with session_maker() as session: + archived_run_ids = repo.get_archived_run_ids(session, run_ids) + + last_seen = (runs[-1].created_at, runs[-1].id) + + # Filter to paid tenants only + tenant_ids = {run.tenant_id for run in runs} + paid_tenants = self._filter_paid_tenants(tenant_ids) + + runs_to_process: list[WorkflowRun] = [] + for run in runs: + summary.total_runs_processed += 1 + + # Skip non-paid tenants + if run.tenant_id not in paid_tenants: + summary.runs_skipped += 1 + continue + + # Skip already archived runs + if run.id in archived_run_ids: + summary.runs_skipped += 1 + continue + + # Check limit + if self.limit and archived_count + len(runs_to_process) >= self.limit: + break + + runs_to_process.append(run) + + if not runs_to_process: + continue + + results = list(executor.map(_archive_with_session, runs_to_process)) + + for run, result in zip(runs_to_process, results): + if result.success: + summary.runs_archived += 1 + archived_count += 1 + click.echo( + click.style( + f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} " + f"run {run.id} (tenant={run.tenant_id}, " + f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)", + fg="green", + ) + ) + else: + summary.runs_failed += 1 + click.echo( + click.style( + f"Failed to archive run {run.id}: {result.error}", + fg="red", + ) + ) + + summary.total_elapsed_time = time.time() - start_time + click.echo( + click.style( + f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: " + f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="white", + ) + ) + + return summary + + def _get_runs_batch( + self, + last_seen: tuple[datetime.datetime, str] | None, + ) -> Sequence[WorkflowRun]: + """Fetch a batch of workflow runs to archive.""" + repo = self._get_workflow_run_repo() + return repo.get_runs_batch_by_time_range( + start_from=self.start_from, + end_before=self.end_before, + last_seen=last_seen, + batch_size=self.batch_size, + run_types=self.ARCHIVED_TYPE, + tenant_ids=self.tenant_ids or None, + ) + + def _build_start_message(self) -> str: + range_desc = f"before {self.end_before.isoformat()}" + if self.start_from: + range_desc = f"between {self.start_from.isoformat()} and 
{self.end_before.isoformat()}" + return ( + f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving " + f"for runs {range_desc} " + f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})" + ) + + def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]: + """Filter tenant IDs to only include paid tenants.""" + if not dify_config.BILLING_ENABLED: + # If billing is not enabled, treat all tenants as paid + return tenant_ids + + if not tenant_ids: + return set() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids)) + except Exception: + logger.exception("Failed to fetch billing plans for tenants") + # On error, skip all tenants in this batch + return set() + + # Filter to paid tenants (any plan except SANDBOX) + paid = set() + for tid, info in bulk_info.items(): + if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM): + paid.add(tid) + + return paid + + def _archive_run( + self, + session: Session, + storage: ArchiveStorage | None, + run: WorkflowRun, + ) -> ArchiveResult: + """Archive a single workflow run.""" + start_time = time.time() + result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + + try: + # Extract data from all tables + table_data, app_logs, trigger_metadata = self._extract_data(session, run) + + if self.dry_run: + # In dry run, just report what would be archived + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + result.tables.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum="", + size_bytes=0, + ) + ) + result.success = True + else: + if storage is None: + raise ArchiveStorageNotConfiguredError("Archive storage not configured") + archive_key = self._get_archive_key(run) + + # Serialize tables for the archive bundle + table_stats: list[TableStats] = [] + table_payloads: dict[str, bytes] = {} + for table_name in self.ARCHIVED_TABLES: + records = table_data.get(table_name, []) + data = ArchiveStorage.serialize_to_jsonl(records) + table_payloads[table_name] = data + checksum = ArchiveStorage.compute_checksum(data) + + table_stats.append( + TableStats( + table_name=table_name, + row_count=len(records), + checksum=checksum, + size_bytes=len(data), + ) + ) + + # Generate and upload archive bundle + manifest = self._generate_manifest(run, table_stats) + manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8") + archive_data = self._build_archive_bundle(manifest_data, table_payloads) + storage.put_object(archive_key, archive_data) + + repo = self._get_workflow_run_repo() + archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata) + session.commit() + + deleted_counts = None + if self.delete_after_archive: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + + logger.info( + "Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s", + run.id, + {s.table_name: s.row_count for s in table_stats}, + archived_log_count, + deleted_counts, + ) + + result.tables = table_stats + result.success = True + + except Exception as e: + logger.exception("Failed to archive workflow run %s", run.id) + result.error = str(e) + session.rollback() + + result.elapsed_time = time.time() - start_time + return result + + def _extract_data( + self, + session: Session, + run: WorkflowRun, + ) -> tuple[dict[str, list[dict[str, Any]]], 
Sequence[WorkflowAppLog], str | None]: + table_data: dict[str, list[dict[str, Any]]] = {} + table_data["workflow_runs"] = [self._row_to_dict(run)] + repo = self._get_workflow_run_repo() + app_logs = repo.get_app_logs_by_run_id(session, run.id) + table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs] + node_exec_repo = self._get_workflow_node_execution_repo(session) + node_exec_records = node_exec_repo.get_executions_by_workflow_run( + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_run_id=run.id, + ) + node_exec_ids = [record.id for record in node_exec_records] + offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids) + table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records] + table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records] + repo = self._get_workflow_run_repo() + pause_records = repo.get_pause_records_by_run_id(session, run.id) + pause_ids = [pause.id for pause in pause_records] + pause_reason_records = repo.get_pause_reason_records_by_run_id( + session, + pause_ids, + ) + table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records] + table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records] + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + trigger_records = trigger_repo.list_by_run_id(run.id) + table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records] + trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None + return table_data, app_logs, trigger_metadata + + @staticmethod + def _row_to_dict(row: Any) -> dict[str, Any]: + mapper = inspect(row).mapper + return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns} + + def _get_archive_key(self, run: WorkflowRun) -> str: + """Get the storage key for the archive bundle.""" + created_at = run.created_at + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run.id}" + ) + return f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + + def _generate_manifest( + self, + run: WorkflowRun, + table_stats: list[TableStats], + ) -> dict[str, Any]: + """Generate a manifest for the archived workflow run.""" + return { + "schema_version": ARCHIVE_SCHEMA_VERSION, + "workflow_run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "created_at": run.created_at.isoformat(), + "archived_at": datetime.datetime.now(datetime.UTC).isoformat(), + "tables": { + stat.table_name: { + "row_count": stat.row_count, + "checksum": stat.checksum, + "size_bytes": stat.size_bytes, + } + for stat in table_stats + }, + } + + def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes: + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive: + archive.writestr("manifest.json", manifest_data) + for table_name in self.ARCHIVED_TABLES: + data = table_payloads.get(table_name) + if data is None: + raise ValueError(f"Missing archive payload for {table_name}") + archive.writestr(f"{table_name}.jsonl", data) + return buffer.getvalue() + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def 
_delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_ids = [run.id for run in runs] + return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids) + + def _get_workflow_node_execution_repo( + self, + session: Session, + ) -> DifyAPIWorkflowNodeExecutionRepository: + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False) + return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/constants.py b/api/services/retention/workflow_run/constants.py new file mode 100644 index 0000000000..162bb4947d --- /dev/null +++ b/api/services/retention/workflow_run/constants.py @@ -0,0 +1,2 @@ +ARCHIVE_SCHEMA_VERSION = "1.0" +ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip" diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py new file mode 100644 index 0000000000..11873bf1b9 --- /dev/null +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -0,0 +1,134 @@ +""" +Delete Archived Workflow Run Service. + +This service deletes archived workflow run data from the database while keeping +archive logs intact. 
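A usage sketch for this deletion service (illustrative; the run id is a placeholder and the import path follows this module's location in the patch):

    from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion

    deletion = ArchivedWorkflowRunDeletion(dry_run=True)
    result = deletion.delete_by_run_id("<workflow_run_id>")
    if not result.success:
        print(result.error)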
+""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass, field +from datetime import datetime + +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +@dataclass +class DeleteResult: + run_id: str + tenant_id: str + success: bool + deleted_counts: dict[str, int] = field(default_factory=dict) + error: str | None = None + elapsed_time: float = 0.0 + + +class ArchivedWorkflowRunDeletion: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def delete_by_run_id(self, run_id: str) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run_id, tenant_id="", success=False) + + repo = self._get_workflow_run_repo() + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + with session_maker() as session: + run = session.get(WorkflowRun, run_id) + if not run: + result.error = f"Workflow run {run_id} not found" + result.elapsed_time = time.time() - start_time + return result + + result.tenant_id = run.tenant_id + if not repo.get_archived_run_ids(session, [run.id]): + result.error = f"Workflow run {run_id} is not archived" + result.elapsed_time = time.time() - start_time + return result + + result = self._delete_run(run) + result.elapsed_time = time.time() - start_time + return result + + def delete_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[DeleteResult]: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + results: list[DeleteResult] = [] + + repo = self._get_workflow_run_repo() + with session_maker() as session: + runs = list( + repo.get_archived_runs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + ) + for run in runs: + results.append(self._delete_run(run)) + + return results + + def _delete_run(self, run: WorkflowRun) -> DeleteResult: + start_time = time.time() + result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False) + if self.dry_run: + result.success = True + result.elapsed_time = time.time() - start_time + return result + + repo = self._get_workflow_run_repo() + try: + deleted_counts = repo.delete_runs_with_related( + [run], + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + result.deleted_counts = deleted_counts + result.success = True + except Exception as e: + result.error = str(e) + result.elapsed_time = time.time() - start_time + return result + + @staticmethod + def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + @staticmethod + def _delete_node_executions( + session: Session, + runs: Sequence[WorkflowRun], + ) -> tuple[int, int]: + from repositories.factory import DifyAPIRepositoryFactory + + run_ids = [run.id for run in runs] + repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False) + ) + return repo.delete_by_runs(session, run_ids) + + def _get_workflow_run_repo(self) -> 
APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + from repositories.factory import DifyAPIRepositoryFactory + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py new file mode 100644 index 0000000000..d4a6e87585 --- /dev/null +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -0,0 +1,481 @@ +""" +Restore Archived Workflow Run Service. + +This service restores archived workflow run data from S3-compatible storage +back to the database. +""" + +import io +import json +import logging +import time +import zipfile +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from datetime import datetime +from typing import Any, cast + +import click +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +from extensions.ext_database import db +from libs.archive_storage import ( + ArchiveStorage, + ArchiveStorageNotConfiguredError, + get_archive_storage, +) +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.factory import DifyAPIRepositoryFactory +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + +logger = logging.getLogger(__name__) + + +# Mapping of table names to SQLAlchemy models +TABLE_MODELS = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, +} + +SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]] + +SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = { + "1.0": {}, +} + + +@dataclass +class RestoreResult: + """Result of restoring a single workflow run.""" + + run_id: str + tenant_id: str + success: bool + restored_counts: dict[str, int] + error: str | None = None + elapsed_time: float = 0.0 + + +class WorkflowRunRestore: + """ + Restore archived workflow run data from storage to database. + + This service reads archived data from storage and restores it to the + database tables. It handles idempotency by skipping records that already + exist in the database. + """ + + def __init__(self, dry_run: bool = False, workers: int = 1): + """ + Initialize the restore service. 
+ + Args: + dry_run: If True, only preview without making changes + workers: Number of concurrent workflow runs to restore + """ + self.dry_run = dry_run + if workers < 1: + raise ValueError("workers must be at least 1") + self.workers = workers + self.workflow_run_repo: APIWorkflowRunRepository | None = None + + def _restore_from_run( + self, + run: WorkflowRun | WorkflowArchiveLog, + *, + session_maker: sessionmaker, + ) -> RestoreResult: + start_time = time.time() + run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id + created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at + result = RestoreResult( + run_id=run_id, + tenant_id=run.tenant_id, + success=False, + restored_counts={}, + ) + + if not self.dry_run: + click.echo( + click.style( + f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})", + fg="white", + ) + ) + + try: + storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + result.error = str(e) + click.echo(click.style(f"Archive storage not configured: {e}", fg="red")) + result.elapsed_time = time.time() - start_time + return result + + prefix = ( + f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/" + f"month={created_at.strftime('%m')}/workflow_run_id={run_id}" + ) + archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}" + try: + archive_data = storage.get_object(archive_key) + except FileNotFoundError: + result.error = f"Archive bundle not found: {archive_key}" + click.echo(click.style(result.error, fg="red")) + result.elapsed_time = time.time() - start_time + return result + + with session_maker() as session: + try: + with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive: + try: + manifest = self._load_manifest_from_zip(archive) + except ValueError as e: + result.error = f"Archive bundle invalid: {e}" + click.echo(click.style(result.error, fg="red")) + return result + + tables = manifest.get("tables", {}) + schema_version = self._get_schema_version(manifest) + for table_name, info in tables.items(): + row_count = info.get("row_count", 0) + if row_count == 0: + result.restored_counts[table_name] = 0 + continue + + if self.dry_run: + result.restored_counts[table_name] = row_count + continue + + member_path = f"{table_name}.jsonl" + try: + data = archive.read(member_path) + except KeyError: + click.echo( + click.style( + f" Warning: Table data not found in archive: {member_path}", + fg="yellow", + ) + ) + result.restored_counts[table_name] = 0 + continue + + records = ArchiveStorage.deserialize_from_jsonl(data) + restored = self._restore_table_records( + session, + table_name, + records, + schema_version=schema_version, + ) + result.restored_counts[table_name] = restored + if not self.dry_run: + click.echo( + click.style( + f" Restored {restored}/{len(records)} records to {table_name}", + fg="white", + ) + ) + + # Verify row counts match manifest + manifest_total = sum(info.get("row_count", 0) for info in tables.values()) + restored_total = sum(result.restored_counts.values()) + + if not self.dry_run: + # Note: restored count might be less than manifest count if records already exist + logger.info( + "Restore verification: manifest_total=%d, restored_total=%d", + manifest_total, + restored_total, + ) + + # Delete the archive log record after successful restore + repo = self._get_workflow_run_repo() + repo.delete_archive_log_by_run_id(session, run_id) + + session.commit() + + result.success = True + if not self.dry_run: + click.echo( + 
click.style( + f"Completed restore for workflow run {run_id}: restored={result.restored_counts}", + fg="green", + ) + ) + + except Exception as e: + logger.exception("Failed to restore workflow run %s", run_id) + result.error = str(e) + session.rollback() + click.echo(click.style(f"Restore failed: {e}", fg="red")) + + result.elapsed_time = time.time() - start_time + return result + + def _get_workflow_run_repo(self) -> APIWorkflowRunRepository: + if self.workflow_run_repo is not None: + return self.workflow_run_repo + + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + return self.workflow_run_repo + + @staticmethod + def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]: + try: + data = archive.read("manifest.json") + except KeyError as e: + raise ValueError("manifest.json missing from archive bundle") from e + return json.loads(data.decode("utf-8")) + + def _restore_table_records( + self, + session: Session, + table_name: str, + records: list[dict[str, Any]], + *, + schema_version: str, + ) -> int: + """ + Restore records to a table. + + Uses INSERT ... ON CONFLICT DO NOTHING for idempotency. + + Args: + session: Database session + table_name: Name of the table + records: List of record dictionaries + schema_version: Archived schema version from manifest + + Returns: + Number of records actually inserted + """ + if not records: + return 0 + + model = TABLE_MODELS.get(table_name) + if not model: + logger.warning("Unknown table: %s", table_name) + return 0 + + column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model) + unknown_fields: set[str] = set() + + # Apply schema mapping, filter to current columns, then convert datetimes + converted_records = [] + for record in records: + mapped = self._apply_schema_mapping(table_name, schema_version, record) + unknown_fields.update(set(mapped.keys()) - column_names) + filtered = {key: value for key, value in mapped.items() if key in column_names} + for key in non_nullable_with_default: + if key in filtered and filtered[key] is None: + filtered.pop(key) + missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None] + if missing_required: + missing_cols = ", ".join(sorted(missing_required)) + raise ValueError( + f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}" + ) + converted = self._convert_datetime_fields(filtered, model) + converted_records.append(converted) + if unknown_fields: + logger.warning( + "Dropped unknown columns for %s (schema_version=%s): %s", + table_name, + schema_version, + ", ".join(sorted(unknown_fields)), + ) + + # Use INSERT ... 
ON CONFLICT DO NOTHING for idempotency + stmt = pg_insert(model).values(converted_records) + stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + + result = session.execute(stmt) + return cast(CursorResult, result).rowcount or 0 + + def _convert_datetime_fields( + self, + record: dict[str, Any], + model: type[DeclarativeBase] | Any, + ) -> dict[str, Any]: + """Convert ISO datetime strings to datetime objects.""" + from sqlalchemy import DateTime + + result = dict(record) + + for column in model.__table__.columns: + if isinstance(column.type, DateTime): + value = result.get(column.key) + if isinstance(value, str): + try: + result[column.key] = datetime.fromisoformat(value) + except ValueError: + pass + + return result + + def _get_schema_version(self, manifest: dict[str, Any]) -> str: + schema_version = manifest.get("schema_version") + if not schema_version: + logger.warning("Manifest missing schema_version; defaulting to 1.0") + schema_version = "1.0" + schema_version = str(schema_version) + if schema_version not in SCHEMA_MAPPERS: + raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.") + return schema_version + + def _apply_schema_mapping( + self, + table_name: str, + schema_version: str, + record: dict[str, Any], + ) -> dict[str, Any]: + # Keep hook for forward/backward compatibility when schema evolves. + mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name) + if mapper is None: + return dict(record) + return mapper(record) + + def _get_model_column_info( + self, + model: type[DeclarativeBase] | Any, + ) -> tuple[set[str], set[str], set[str]]: + columns = list(model.__table__.columns) + column_names = {column.key for column in columns} + required_columns = { + column.key + for column in columns + if not column.nullable + and column.default is None + and column.server_default is None + and not column.autoincrement + } + non_nullable_with_default = { + column.key + for column in columns + if not column.nullable + and (column.default is not None or column.server_default is not None or column.autoincrement) + } + return column_names, required_columns, non_nullable_with_default + + def restore_batch( + self, + tenant_ids: list[str] | None, + start_date: datetime, + end_date: datetime, + limit: int = 100, + ) -> list[RestoreResult]: + """ + Restore multiple workflow runs by time range. 
+ + Args: + tenant_ids: Optional tenant IDs + start_date: Start date filter + end_date: End date filter + limit: Maximum number of runs to restore (default: 100) + + Returns: + List of RestoreResult objects + """ + results: list[RestoreResult] = [] + if tenant_ids is not None and not tenant_ids: + return results + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = self._get_workflow_run_repo() + + with session_maker() as session: + archive_logs = repo.get_archived_logs_by_time_range( + session=session, + tenant_ids=tenant_ids, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + click.echo( + click.style( + f"Found {len(archive_logs)} archived workflow runs to restore", + fg="white", + ) + ) + + def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult: + return self._restore_from_run( + archive_log, + session_maker=session_maker, + ) + + with ThreadPoolExecutor(max_workers=self.workers) as executor: + results = list(executor.map(_restore_with_session, archive_logs)) + + total_counts: dict[str, int] = {} + for result in results: + for table_name, count in result.restored_counts.items(): + total_counts[table_name] = total_counts.get(table_name, 0) + count + success_count = sum(1 for result in results if result.success) + + if self.dry_run: + click.echo( + click.style( + f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}", + fg="yellow", + ) + ) + else: + click.echo( + click.style( + f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}", + fg="green", + ) + ) + + return results + + def restore_by_run_id( + self, + run_id: str, + ) -> RestoreResult: + """ + Restore a single workflow run by run ID. + """ + repo = self._get_workflow_run_repo() + archive_log = repo.get_archived_log_by_run_id(run_id) + + if not archive_log: + click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red")) + return RestoreResult( + run_id=run_id, + tenant_id="", + success=False, + restored_counts={}, + error=f"Workflow run archive {run_id} not found", + ) + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + result = self._restore_from_run(archive_log, session_maker=session_maker) + if self.dry_run and result.success: + click.echo( + click.style( + f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}", + fg="yellow", + ) + ) + return result diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 8574d30255..efc76c33bc 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun +from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService @@ -173,7 +173,80 @@ class WorkflowAppService: "data": items, } - def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]: + def get_paginate_workflow_archive_logs( + self, + *, + session: Session, + app_model: App, + page: int = 1, + limit: int = 20, + ): + """ + Get paginate workflow archive logs using SQLAlchemy 2.0 style. 
+ """ + stmt = select(WorkflowArchiveLog).where( + WorkflowArchiveLog.tenant_id == app_model.tenant_id, + WorkflowArchiveLog.app_id == app_model.id, + WorkflowArchiveLog.log_id.isnot(None), + ) + + stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc()) + + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = session.scalar(count_stmt) or 0 + + offset_stmt = stmt.offset((page - 1) * limit).limit(limit) + + logs = list(session.scalars(offset_stmt).all()) + account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT} + end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER} + + accounts_by_id = {} + if account_ids: + accounts_by_id = { + account.id: account + for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all() + } + + end_users_by_id = {} + if end_user_ids: + end_users_by_id = { + end_user.id: end_user + for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all() + } + + items = [] + for log in logs: + if log.created_by_role == CreatorUserRole.ACCOUNT: + created_by_account = accounts_by_id.get(log.created_by) + created_by_end_user = None + elif log.created_by_role == CreatorUserRole.END_USER: + created_by_account = None + created_by_end_user = end_users_by_id.get(log.created_by) + else: + created_by_account = None + created_by_end_user = None + + items.append( + { + "id": log.id, + "workflow_run": log.workflow_run_summary, + "trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata), + "created_by_account": created_by_account, + "created_by_end_user": created_by_end_user, + "created_at": log.log_created_at, + } + ) + + return { + "page": page, + "limit": limit, + "total": total, + "has_more": total > page * limit, + "data": items, + } + + def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]: metadata: dict[str, Any] | None = self._safe_json_loads(meta_val) if not metadata: return {} diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e7dead8a56..62e6497e9d 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,11 +4,11 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment @@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str): logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green")) start_at = time.perf_counter() - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() - if not dataset_document: - logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) - db.session.close() - return + with session_factory.create_session() as session: + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() + if not dataset_document: + 
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) + return - if dataset_document.indexing_status != "completed": - db.session.close() - return + if dataset_document.indexing_status != "completed": + return - indexing_cache_key = f"document_{dataset_document.id}_indexing" + indexing_cache_key = f"document_{dataset_document.id}_indexing" - try: - dataset = dataset_document.dataset - if not dataset: - raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") + try: + dataset = dataset_document.dataset + if not dataset: + raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") - segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == "completed", + ) + .order_by(DocumentSegment.position.asc()) + .all() ) - .order_by(DocumentSegment.position.asc()) - .all() - ) - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + documents = [] + multimodal_documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + documents.append(document) + + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) + + # delete auto disable log + session.query(DatasetAutoDisableLog).where( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + + # update segment to enable + session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( + { + DocumentSegment.enabled: True, + DocumentSegment.disabled_at: None, + DocumentSegment.disabled_by: None, + DocumentSegment.updated_at: naive_utc_now(), + } ) - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": 
child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) + session.commit() - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - - # delete auto disable log - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() - - # update segment to enable - db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( - { - DocumentSegment.enabled: True, - DocumentSegment.disabled_at: None, - DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: naive_utc_now(), - } - ) - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") - ) - except Exception as e: - logger.exception("add document to index failed") - dataset_document.enabled = False - dataset_document.disabled_at = naive_utc_now() - dataset_document.indexing_status = "error" - dataset_document.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info( + click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") + ) + except Exception as e: + logger.exception("add document to index failed") + dataset_document.enabled = False + dataset_document.disabled_at = naive_utc_now() + dataset_document.indexing_status = "error" + dataset_document.error = str(e) + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 775814318b..fc6bf03454 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -5,9 +5,9 @@ import click from celery import shared_task from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" active_jobs_key = f"annotation_import_active:{tenant_id}" - # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + with session_factory.create_session() as session: + # get app info + app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - if app: - try: - documents = [] - for content in 
content_list: - annotation = MessageAnnotation( - app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id + if app: + try: + documents = [] + for content in content_list: + annotation = MessageAnnotation( + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id + ) + session.add(annotation) + session.flush() + + document = Document( + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + # if annotation reply is enabled , batch add annotations' index + app_annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) - db.session.add(annotation) - db.session.flush() - document = Document( - page_content=content["question"], - metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, - ) - documents.append(document) - # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - ) + if app_annotation_setting: + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) + ) + if not dataset_collection_binding: + raise NotFound("App annotation setting not found") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id, + ) - if app_annotation_setting: - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, "annotation" + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.create(documents, duplicate_check=True) + + session.commit() + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", ) ) - if not dataset_collection_binding: - raise NotFound("App annotation setting not found") - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - embedding_model_provider=dataset_collection_binding.provider_name, - embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id, - ) - - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - vector.create(documents, duplicate_check=True) - - db.session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - "Build index successful for batch import annotation: {} latency: {}".format( - job_id, end_at - start_at - ), - fg="green", - ) - ) - except Exception as e: - db.session.rollback() - redis_client.setex(indexing_cache_key, 600, "error") - indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" - redis_client.setex(indexing_error_msg_key, 600, str(e)) - logger.exception("Build index for batch import annotations failed") - finally: - # Clean up active job tracking to release concurrency slot - try: - 
redis_client.zrem(active_jobs_key, job_id) - logger.debug("Released concurrency slot for job: %s", job_id) - except Exception as cleanup_error: - # Log but don't fail if cleanup fails - the job will be auto-expired - logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) - - # Close database session - db.session.close() + except Exception as e: + session.rollback() + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" + redis_client.setex(indexing_error_msg_key, 600, str(e)) + logger.exception("Build index for batch import annotations failed") + finally: + # Clean up active job tracking to release concurrency slot + try: + redis_client.zrem(active_jobs_key, job_id) + logger.debug("Released concurrency slot for job: %s", job_id) + except Exception as cleanup_error: + # Log but don't fail if cleanup fails - the job will be auto-expired + logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error) diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index c0020b29ed..7b5cd46b00 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import exists, select +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation @@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) - if not app: - logger.info(click.style(f"App not found: {app_id}", fg="red")) - db.session.close() - return + with session_factory.create_session() as session: + app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) + if not app: + logger.info(click.style(f"App not found: {app_id}", fg="red")) + return - app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - - if not app_annotation_setting: - logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red")) - db.session.close() - return - - disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" - disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" - - try: - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - collection_binding_id=app_annotation_setting.collection_binding_id, + app_annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() ) + if not app_annotation_setting: + logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red")) + return + + disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" + 
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" + try: - if annotations_exists: - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - vector.delete() - except Exception: - logger.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, "completed") + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, + ) - # delete annotation setting - db.session.delete(app_annotation_setting) - db.session.commit() + try: + if annotations_exists: + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete() + except Exception: + logger.exception("Delete annotation index failed when annotation deleted.") + redis_client.setex(disable_app_annotation_job_key, 600, "completed") - end_at = time.perf_counter() - logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("Annotation batch deleted index failed") - redis_client.setex(disable_app_annotation_job_key, 600, "error") - disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}" - redis_client.setex(disable_app_annotation_error_key, 600, str(e)) - finally: - redis_client.delete(disable_app_annotation_key) - db.session.close() + # delete annotation setting + session.delete(app_annotation_setting) + session.commit() + + end_at = time.perf_counter() + logger.info( + click.style( + f"App annotations index deleted : {app_id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception as e: + logger.exception("Annotation batch deleted index failed") + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}" + redis_client.setex(disable_app_annotation_error_key, 600, str(e)) + finally: + redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index be1de3cdd2..4f8e2fec7a 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -5,9 +5,9 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset @@ -33,92 +33,98 @@ def enable_annotation_reply_task( logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + with session_factory.create_session() as session: + app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() - if not app: - logger.info(click.style(f"App not found: {app_id}", fg="red")) - db.session.close() - return + if not app: + logger.info(click.style(f"App not found: {app_id}", fg="red")) + return - annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() - 
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" - enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" + annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() + enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" + enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" - try: - documents = [] - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" - ) - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() - if annotation_setting: - if dataset_collection_binding.id != annotation_setting.collection_binding_id: - old_dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, "annotation" - ) - ) - if old_dataset_collection_binding and annotations: - old_dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - embedding_model_provider=old_dataset_collection_binding.provider_name, - embedding_model=old_dataset_collection_binding.model_name, - collection_binding_id=old_dataset_collection_binding.id, - ) - - old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"]) - try: - old_vector.delete() - except Exception as e: - logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) - annotation_setting.score_threshold = score_threshold - annotation_setting.collection_binding_id = dataset_collection_binding.id - annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = naive_utc_now() - db.session.add(annotation_setting) - else: - new_app_annotation_setting = AppAnnotationSetting( - app_id=app_id, - score_threshold=score_threshold, - collection_binding_id=dataset_collection_binding.id, - created_user_id=user_id, - updated_user_id=user_id, + try: + documents = [] + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_provider_name, embedding_model_name, "annotation" ) - db.session.add(new_app_annotation_setting) + annotation_setting = ( + session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + ) + if annotation_setting: + if dataset_collection_binding.id != annotation_setting.collection_binding_id: + old_dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + annotation_setting.collection_binding_id, "annotation" + ) + ) + if old_dataset_collection_binding and annotations: + old_dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=old_dataset_collection_binding.provider_name, + embedding_model=old_dataset_collection_binding.model_name, + collection_binding_id=old_dataset_collection_binding.id, + ) - dataset = Dataset( - id=app_id, - tenant_id=tenant_id, - indexing_technique="high_quality", - embedding_model_provider=embedding_provider_name, - embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id, - ) - if annotations: - for annotation in annotations: - document = Document( - page_content=annotation.question_text, - metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", 
"app_id"]) + try: + old_vector.delete() + except Exception as e: + logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) + annotation_setting.score_threshold = score_threshold + annotation_setting.collection_binding_id = dataset_collection_binding.id + annotation_setting.updated_user_id = user_id + annotation_setting.updated_at = naive_utc_now() + session.add(annotation_setting) + else: + new_app_annotation_setting = AppAnnotationSetting( + app_id=app_id, + score_threshold=score_threshold, + collection_binding_id=dataset_collection_binding.id, + created_user_id=user_id, + updated_user_id=user_id, ) - documents.append(document) + session.add(new_app_annotation_setting) - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - try: - vector.delete_by_metadata_field("app_id", app_id) - except Exception as e: - logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) - vector.create(documents) - db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, "completed") - end_at = time.perf_counter() - logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("Annotation batch created index failed") - redis_client.setex(enable_app_annotation_job_key, 600, "error") - enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}" - redis_client.setex(enable_app_annotation_error_key, 600, str(e)) - db.session.rollback() - finally: - redis_client.delete(enable_app_annotation_key) - db.session.close() + dataset = Dataset( + id=app_id, + tenant_id=tenant_id, + indexing_technique="high_quality", + embedding_model_provider=embedding_provider_name, + embedding_model=embedding_model_name, + collection_binding_id=dataset_collection_binding.id, + ) + if annotations: + for annotation in annotations: + document = Document( + page_content=annotation.question_text, + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, + ) + documents.append(document) + + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + try: + vector.delete_by_metadata_field("app_id", app_id) + except Exception as e: + logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) + vector.create(documents) + session.commit() + redis_client.setex(enable_app_annotation_job_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"App annotations added to index: {app_id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception as e: + logger.exception("Annotation batch created index failed") + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}" + redis_client.setex(enable_app_annotation_error_key, 600, str(e)) + session.rollback() + finally: + redis_client.delete(enable_app_annotation_key) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index f8aac5b469..b51884148e 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,13 +10,13 @@ from typing import Any from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from configs import dify_config from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator from core.app.entities.app_invoke_entities 
import InvokeFrom from core.app.layers.trigger_post_layer import TriggerPostLayer -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import Account from models.enums import CreatorUserRole, WorkflowTriggerStatus from models.model import App, EndUser, Tenant @@ -98,10 +98,7 @@ def _execute_workflow_common( ): """Execute workflow with common logic and trigger log updates.""" - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) # Get trigger log @@ -157,7 +154,7 @@ def _execute_workflow_common( root_node_id=trigger_data.root_node_id, graph_engine_layers=[ # TODO: Re-enable TimeSliceLayer after the HITL release. - TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory), + TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id), ], ) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 3e1bd16cc7..74b939e84d 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment from models.model import UploadFile @@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form """ logger.info(click.style("Start batch clean documents when documents deleted", fg="green")) start_at = time.perf_counter() + if not doc_form: + raise ValueError("doc_form is required") - try: - if not doc_form: - raise ValueError("doc_form is required") - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Document has no dataset") + if not dataset: + raise Exception("Document has no dataset") - db.session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ).delete(synchronize_session=False) + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ).delete(synchronize_session=False) - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) - ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ).all() + # check segment is exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + 
index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + for image_file in image_files: + try: + if image_file and image_file.key: + storage.delete(image_file.key) + except Exception: + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file.id, + ) + stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(stmt) + session.delete(segment) + if file_ids: + files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() + for file in files: try: - if image_file and image_file.key: - storage.delete(image_file.key) + storage.delete(file.key) except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - upload_file_id, - ) - db.session.delete(image_file) - db.session.delete(segment) + logger.exception("Delete file failed when document deleted, file_id: %s", file.id) + stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(stmt) - db.session.commit() - if file_ids: - files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() - for file in files: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file.id) - db.session.delete(file) + session.commit() - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned documents when documents deleted latency: {end_at - start_at}", - fg="green", + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned documents when documents deleted latency: {end_at - start_at}", + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned documents when documents deleted failed") - finally: - db.session.close() + except Exception: + logger.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index bd95af2614..8ee09d5738 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -9,9 +9,9 @@ import pandas as pd from celery import shared_task from sqlalchemy import func +from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper @@ -48,104 +48,107 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" - try: - dataset = db.session.get(Dataset, dataset_id) - if not dataset: - raise ValueError("Dataset not exist.") + with session_factory.create_session() as session: + try: + dataset = session.get(Dataset, dataset_id) + if not dataset: + raise ValueError("Dataset not exist.") - dataset_document = db.session.get(Document, 
document_id) - if not dataset_document: - raise ValueError("Document not exist.") + dataset_document = session.get(Document, document_id) + if not dataset_document: + raise ValueError("Document not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - raise ValueError("Document is not available.") + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + raise ValueError("Document is not available.") - upload_file = db.session.get(UploadFile, upload_file_id) - if not upload_file: - raise ValueError("UploadFile not found.") + upload_file = session.get(UploadFile, upload_file_id) + if not upload_file: + raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file.key, file_path) - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if dataset_document.doc_form == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} + else: + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") + + document_segments = [] + embedding_model = None + if dataset.indexing_technique == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + word_count_change = 0 + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens( + texts=[segment["content"] for segment in content] + ) + else: + tokens_list = [0] * len(content) + + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - 
tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + assert dataset_document.word_count is not None + dataset_document.word_count += word_count_change + session.add(dataset_document) - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) + session.commit() + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) ) - else: - tokens_list = [0] * len(content) - - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - db.session.add(segment_document) - document_segments.append(segment_document) - - assert dataset_document.word_count is not None - dataset_document.word_count += word_count_change - db.session.add(dataset_document) - - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - db.session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") - finally: - db.session.close() + except Exception: + logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index b4d82a150d..0d51a743ad 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from extensions.ext_database import db from extensions.ext_storage import storage from models import WorkflowType from models.dataset import ( @@ -53,135 +53,155 @@ def clean_dataset_task( logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green")) start_at = time.perf_counter() - try: - dataset = Dataset( - id=dataset_id, - tenant_id=tenant_id, - indexing_technique=indexing_technique, - 
index_struct=index_struct, - collection_binding_id=collection_binding_id, - ) - documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() - # Use JOIN to fetch attachments with bindings in a single query - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) - ).all() - - # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace - # This ensures all invalid doc_form values are properly handled - if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): - # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexStructureType - - doc_form = IndexStructureType.PARAGRAPH_INDEX - logger.info( - click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") - ) - - # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure - # This ensures Document/Segment deletion can continue even if vector database cleanup fails + with session_factory.create_session() as session: try: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) - logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green")) - except Exception: - logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red")) - # Continue with document and segment deletion even if vector cleanup fails - logger.info( - click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow") + dataset = Dataset( + id=dataset_id, + tenant_id=tenant_id, + indexing_technique=indexing_technique, + index_struct=index_struct, + collection_binding_id=collection_binding_id, ) + documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + ) + ).all() - if documents is None or len(documents) == 0: - logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) - else: - logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) + # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace + # This ensures all invalid doc_form values are properly handled + if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): + # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup + from core.rag.index_processor.constant.index_type import IndexStructureType - for document in documents: - db.session.delete(document) - # delete document file + 
doc_form = IndexStructureType.PARAGRAPH_INDEX + logger.info( + click.style( + f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", + fg="yellow", + ) + ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - for upload_file_id in image_upload_file_ids: - image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() - if image_file is None: - continue + # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure + # This ensures Document/Segment deletion can continue even if vector database cleanup fails + try: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green")) + except Exception: + logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red")) + # Continue with document and segment deletion even if vector cleanup fails + logger.info( + click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow") + ) + + if documents is None or len(documents) == 0: + logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) + else: + logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) + + for document in documents: + session.delete(document) + + segment_ids = [segment.id for segment in segments] + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + for image_file in image_files: + if image_file is None: + continue + try: + storage.delete(image_file.key) + except Exception: + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file.id, + ) + stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(stmt) + + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + # delete segment attachments + if attachments_with_bindings: + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + for binding, attachment_file in attachments_with_bindings: try: - storage.delete(image_file.key) + storage.delete(attachment_file.key) except Exception: logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - upload_file_id, + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, ) - db.session.delete(image_file) - db.session.delete(segment) - # delete segment attachments - if attachments_with_bindings: - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - db.session.delete(attachment_file) - db.session.delete(binding) + attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) - db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == 
dataset_id).delete() - db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() - db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() - # delete dataset metadata - db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() - db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() - # delete pipeline and workflow - if pipeline_id: - db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() - db.session.query(Workflow).where( - Workflow.tenant_id == tenant_id, - Workflow.app_id == pipeline_id, - Workflow.type == WorkflowType.RAG_PIPELINE, - ).delete() - # delete files - if documents: - for document in documents: - try: + binding_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(binding_ids) + ) + session.execute(binding_delete_stmt) + + session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() + session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() + session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() + # delete dataset metadata + session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() + session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + # delete pipeline and workflow + if pipeline_id: + session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() + session.query(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ).delete() + # delete files + if documents: + file_ids = [] + for document in documents: if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() - ) - if not file: - continue - storage.delete(file.key) - db.session.delete(file) - except Exception: - continue + file_ids.append(file_id) + files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + for file in files: + storage.delete(file.key) - db.session.commit() - end_at = time.perf_counter() - logger.info( - click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green") - ) - except Exception: - # Add rollback to prevent dirty session state in case of exceptions - # This ensures the database session is properly cleaned up - try: - db.session.rollback() - logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow")) + file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(file_delete_stmt) + + session.commit() + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", + fg="green", + ) + ) except Exception: - logger.exception("Failed to rollback database session") + # Add rollback to prevent dirty session state in case of exceptions + # This ensures the database session is properly cleaned up + try: + session.rollback() + logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow")) + except Exception: + 
logger.exception("Failed to rollback database session") - logger.exception("Cleaned dataset when dataset deleted failed") - finally: - db.session.close() + logger.exception("Cleaned dataset when dataset deleted failed") + finally: + # Explicitly close the session for test expectations and safety + try: + session.close() + except Exception: + logger.exception("Failed to close database session") diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 6d2feb1da3..86e7cc7160 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile @@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Document has no dataset") + if not dataset: + raise Exception("Document has no dataset") - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - # Use JOIN to fetch attachments with bindings in a single query - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.tenant_id == dataset.tenant_id, - SegmentAttachmentBinding.dataset_id == dataset_id, - SegmentAttachmentBinding.document_id == document_id, - ) - ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() + # check segment is exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - for upload_file_id in image_upload_file_ids: - image_file = 
db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() - if image_file is None: - continue + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + ).all() + for image_file in image_files: + if image_file is None: + continue + try: + storage.delete(image_file.key) + except Exception: + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file.id, + ) + + image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(image_file_delete_stmt) + session.delete(segment) + + session.commit() + if file_id: + file = session.query(UploadFile).where(UploadFile.id == file_id).first() + if file: try: - storage.delete(image_file.key) + storage.delete(file.key) + except Exception: + logger.exception("Delete file failed when document deleted, file_id: %s", file_id) + session.delete(file) + # delete segment attachments + if attachments_with_bindings: + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) except Exception: logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - upload_file_id, + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + binding.attachment_id, ) - db.session.delete(image_file) - db.session.delete(segment) + attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) - db.session.commit() - if file_id: - file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() - if file: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file_id) - db.session.delete(file) - db.session.commit() - # delete segment attachments - if attachments_with_bindings: - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - db.session.delete(attachment_file) - db.session.delete(binding) + binding_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(binding_ids) + ) + session.execute(binding_delete_stmt) - # delete dataset metadata binding - db.session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() - db.session.commit() + # delete dataset metadata binding + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ).delete() + session.commit() - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", - fg="green", + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned document when document deleted failed") 
- finally: - db.session.close() + except Exception: + logger.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 771b43f9b0..bcca1bf49f 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -3,10 +3,10 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment logger = logging.getLogger(__name__) @@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Document has no dataset") - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - for document_id in document_ids: - document = db.session.query(Document).where(Document.id == document_id).first() - db.session.delete(document) + if not dataset: + raise Exception("Document has no dataset") + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - index_node_ids = [segment.index_node_id for segment in segments] + document_delete_stmt = delete(Document).where(Document.id.in_(document_ids)) + session.execute(document_delete_stmt) - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + for document_id in document_ids: + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + index_node_ids = [segment.index_node_id for segment in segments] - for segment in segments: - db.session.delete(segment) - db.session.commit() - end_at = time.perf_counter() - logger.info( - click.style( - "Clean document when import form notion document deleted end :: {} latency: {}".format( - dataset_id, end_at - start_at - ), - fg="green", + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + end_at = time.perf_counter() + logger.info( + click.style( + "Clean document when import form notion document deleted end :: {} latency: {}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned document when import form notion document deleted failed") - finally: - db.session.close() + except Exception: + logger.exception("Cleaned document when import form notion document deleted failed") diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 6b2907cffd..b5e472d71e 100644 --- a/api/tasks/create_segment_to_index_task.py +++ 
b/api/tasks/create_segment_to_index_task.py @@ -4,9 +4,9 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() - if not segment: - logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) - db.session.close() - return - - if segment.status != "waiting": - db.session.close() - return - - indexing_cache_key = f"segment_{segment.id}_indexing" - - try: - # update segment status to indexing - db.session.query(DocumentSegment).filter_by(id=segment.id).update( - { - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: naive_utc_now(), - } - ) - db.session.commit() - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - - dataset = segment.dataset - - if not dataset: - logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + if not segment: + logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - dataset_document = segment.document - - if not dataset_document: - logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + if segment.status != "waiting": return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) - return + indexing_cache_key = f"segment_{segment.id}_indexing" - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, [document]) + try: + # update segment status to indexing + session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: naive_utc_now(), + } + ) + session.commit() + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) - # update segment to completed - db.session.query(DocumentSegment).filter_by(id=segment.id).update( - { - DocumentSegment.status: "completed", - DocumentSegment.completed_at: naive_utc_now(), - } - ) - db.session.commit() + dataset = segment.dataset - end_at = time.perf_counter() - logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("create segment to index failed") - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.status = "error" - segment.error = str(e) - db.session.commit() 
- finally: - redis_client.delete(indexing_cache_key) - db.session.close() + if not dataset: + logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + return + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) + return + + index_type = dataset.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.load(dataset, [document]) + + # update segment to completed + session.query(DocumentSegment).filter_by(id=segment.id).update( + { + DocumentSegment.status: "completed", + DocumentSegment.completed_at: naive_utc_now(), + } + ) + session.commit() + + end_at = time.perf_counter() + logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception("create segment to index failed") + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.status = "error" + segment.error = str(e) + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 3d13afdec0..fa844a8647 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,11 +4,11 @@ import time import click from celery import shared_task # type: ignore +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset: - raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX - index_processor = IndexProcessorFactory(index_type).init_index_processor() - if action == "upgrade": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, + if not dataset: + raise Exception("Dataset not found") + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if action == "upgrade": + dataset_documents = ( + session.query(DatasetDocument) + .where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) 
+ .all() ) - .all() - ) - if dataset_documents: - dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False - ) - db.session.commit() + if dataset_documents: + dataset_documents_ids = [doc.id for doc in dataset_documents] + session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + session.commit() - for dataset_document in dataset_documents: - try: - # add from vector index - segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) - .order_by(DocumentSegment.position.asc()) - .all() - ) - if segments: - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + for dataset_document in dataset_documents: + try: + # add from vector index + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == True, ) - - documents.append(document) - # save vector index - # clean keywords - index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) - index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False - ) - db.session.commit() - except Exception as e: - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False - ) - db.session.commit() - elif action == "update": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - .all() - ) - # add new index - if dataset_documents: - # update document status - dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False - ) - db.session.commit() - - # clean index - index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) - - for dataset_document in dataset_documents: - # update from vector index - try: - segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) - .order_by(DocumentSegment.position.asc()) - .all() - ) - if segments: - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - 
"doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) - # save vector index - index_processor.load( - dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + .order_by(DocumentSegment.position.asc()) + .all() ) - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False - ) - db.session.commit() - except Exception as e: - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False - ) - db.session.commit() - else: - # clean collection - index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) - end_at = time.perf_counter() - logging.info( - click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") - ) - except Exception: - logging.exception("Deal dataset vector index failed") - finally: - db.session.close() + documents.append(document) + # save vector index + # clean keywords + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) + index_processor.load(dataset, documents, with_keywords=False) + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + session.commit() + except Exception as e: + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + session.commit() + elif action == "update": + dataset_documents = ( + session.query(DatasetDocument) + .where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) + # add new index + if dataset_documents: + # update document status + dataset_documents_ids = [doc.id for doc in dataset_documents] + session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + session.commit() + + # clean index + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + for dataset_document in dataset_documents: + # update from vector index + try: + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == True, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + multimodal_documents = [] + for segment in segments: + document = 
Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + documents.append(document) + # save vector index + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + session.commit() + except Exception as e: + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info( + click.style( + "Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Deal dataset vector index failed") diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 1c7de3b1ce..0047e04a17 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -5,11 +5,11 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() - if not dataset: - raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX - index_processor = IndexProcessorFactory(index_type).init_index_processor() - if action == "remove": - index_processor.clean(dataset, None, with_keywords=False) - elif action == "add": - 
dataset_documents = db.session.scalars( - select(DatasetDocument).where( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - ).all() + if not dataset: + raise Exception("Dataset not found") + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if action == "remove": + index_processor.clean(dataset, None, with_keywords=False) + elif action == "add": + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() - if dataset_documents: - dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False - ) - db.session.commit() + if dataset_documents: + dataset_documents_ids = [doc.id for doc in dataset_documents] + session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) + session.commit() - for dataset_document in dataset_documents: - try: - # add from vector index - segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) - .order_by(DocumentSegment.position.asc()) - .all() - ) - if segments: - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + for dataset_document in dataset_documents: + try: + # add from vector index + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == True, ) - - documents.append(document) - # save vector index - index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False - ) - db.session.commit() - except Exception as e: - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False - ) - db.session.commit() - elif action == "update": - dataset_documents = db.session.scalars( - select(DatasetDocument).where( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - ).all() - # add new index - if dataset_documents: - # update document status - dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False - ) - db.session.commit() - - # clean index - index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) - - for dataset_document in dataset_documents: - # update from vector index - try: - segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == 
dataset_document.id, DocumentSegment.enabled == True) - .order_by(DocumentSegment.position.asc()) - .all() - ) - if segments: - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) - # save vector index - index_processor.load( - dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + .order_by(DocumentSegment.position.asc()) + .all() ) - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False - ) - db.session.commit() - except Exception as e: - db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False - ) - db.session.commit() - else: - # clean collection - index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + if segments: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) - end_at = time.perf_counter() - logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("Deal dataset vector index failed") - finally: - db.session.close() + documents.append(document) + # save vector index + index_processor.load(dataset, documents, with_keywords=False) + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + session.commit() + except Exception as e: + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + session.commit() + elif action == "update": + dataset_documents = session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + # add new index + if dataset_documents: + # update document status + dataset_documents_ids = [doc.id for doc in dataset_documents] + session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( + 
{"indexing_status": "indexing"}, synchronize_session=False + ) + session.commit() + + # clean index + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + for dataset_document in dataset_documents: + # update from vector index + try: + segments = ( + session.query(DocumentSegment) + .where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == True, + ) + .order_by(DocumentSegment.position.asc()) + .all() + ) + if segments: + documents = [] + multimodal_documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + documents.append(document) + # save vector index + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) + session.commit() + except Exception as e: + session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) + session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("Deal dataset vector index failed") diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index cb703cc263..ecf6f9cb39 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -3,7 +3,7 @@ import logging from celery import shared_task from configs import dify_config -from extensions.ext_database import db +from core.db.session_factory import session_factory from models import Account from services.billing_service import BillingService from tasks.mail_account_deletion_task import send_deletion_success_task @@ -13,16 +13,17 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_account_task(account_id): - account = db.session.query(Account).where(Account.id == account_id).first() - try: - if dify_config.BILLING_ENABLED: - BillingService.delete_account(account_id) - except Exception: - logger.exception("Failed to delete account %s from billing service.", account_id) - raise + with session_factory.create_session() as session: + account = 
session.query(Account).where(Account.id == account_id).first() + try: + if dify_config.BILLING_ENABLED: + BillingService.delete_account(account_id) + except Exception: + logger.exception("Failed to delete account %s from billing service.", account_id) + raise - if not account: - logger.error("Account %s not found.", account_id) - return - # send success email - send_deletion_success_task.delay(account.email) + if not account: + logger.error("Account %s not found.", account_id) + return + # send success email + send_deletion_success_task.delay(account.email) diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py index 756b67c93e..9664b8ac73 100644 --- a/api/tasks/delete_conversation_task.py +++ b/api/tasks/delete_conversation_task.py @@ -4,7 +4,7 @@ import time import click from celery import shared_task -from extensions.ext_database import db +from core.db.session_factory import session_factory from models import ConversationVariable from models.model import Message, MessageAnnotation, MessageFeedback from models.tools import ToolConversationVariables, ToolFile @@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str): ) start_at = time.perf_counter() - try: - db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(ToolConversationVariables).where( - ToolConversationVariables.conversation_id == conversation_id - ).delete(synchronize_session=False) - - db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) - - db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) - - db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - db.session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}", - fg="green", + with session_factory.create_session() as session: + try: + session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( + synchronize_session=False ) - ) - except Exception as e: - logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) - db.session.rollback() - raise e - finally: - db.session.close() + session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + session.query(ToolConversationVariables).where( + ToolConversationVariables.conversation_id == conversation_id + ).delete(synchronize_session=False) + + session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + + session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) + + session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + 
synchronize_session=False + ) + + session.commit() + + end_at = time.perf_counter() + logger.info( + click.style( + ( + f"Succeeded cleaning data from db for conversation_id {conversation_id} " + f"latency: {end_at - start_at}" + ), + fg="green", + ) + ) + + except Exception: + logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) + session.rollback() + raise diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index bea5c952cf..bfa709502c 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from models.dataset import Dataset, Document, SegmentAttachmentBinding from models.model import UploadFile @@ -26,49 +26,52 @@ def delete_segment_from_index_task( """ logger.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) - return + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) + return - dataset_document = db.session.query(Document).where(Document.id == document_id).first() - if not dataset_document: - return + dataset_document = session.query(Document).where(Document.id == document_id).first() + if not dataset_document: + return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info("Document not in valid state for index operations, skipping") - return - doc_form = dataset_document.doc_form + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logging.info("Document not in valid state for index operations, skipping") + return + doc_form = dataset_document.doc_form - # Proceed with index cleanup using the index_node_ids directly - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean( - dataset, - index_node_ids, - with_keywords=True, - delete_child_chunks=True, - precomputed_child_node_ids=child_node_ids, - ) - if dataset.is_multimodal: - # delete segment attachment binding - segment_attachment_bindings = ( - db.session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() + # Proceed with index cleanup using the index_node_ids directly + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, + index_node_ids, + with_keywords=True, + delete_child_chunks=True, + precomputed_child_node_ids=child_node_ids, ) - if segment_attachment_bindings: - attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] - index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) - for binding in segment_attachment_bindings: - db.session.delete(binding) - # delete upload file - db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) - 
db.session.commit() + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + session.delete(binding) + # delete upload file + session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + session.commit() - end_at = time.perf_counter() - logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("delete segment from index failed") - finally: - db.session.close() + end_at = time.perf_counter() + logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) + except Exception: + logger.exception("delete segment from index failed") diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 6b5f01b416..0ce6429a94 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str): logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() - if not segment: - logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) - db.session.close() - return - - if segment.status != "completed": - logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red")) - db.session.close() - return - - indexing_cache_key = f"segment_{segment.id}_indexing" - - try: - dataset = segment.dataset - - if not dataset: - logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + if not segment: + logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - dataset_document = segment.document - - if not dataset_document: - logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + if segment.status != "completed": + logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) - return + indexing_cache_key = f"segment_{segment.id}_indexing" - index_type = dataset_document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [segment.index_node_id]) + try: + dataset = segment.dataset - end_at = time.perf_counter() - 
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("remove segment from index failed") - segment.enabled = True - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + if not dataset: + logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + return + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) + return + + index_type = dataset_document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_processor.clean(dataset, [segment.index_node_id]) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment removed from index: {segment.id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("remove segment from index failed") + segment.enabled = True + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index c2a3de29f4..03635902d1 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument @@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) - db.session.close() - return + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) + return - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() - if not dataset_document: - logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) - db.session.close() - return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) - db.session.close() - return - # sync index processor - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if not dataset_document: + logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + 
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ) - ).all() - - if not segments: - db.session.close() - return - - try: - index_node_ids = [segment.index_node_id for segment in segments] - if dataset.is_multimodal: - segment_ids = [segment.id for segment in segments] - segment_attachment_bindings = ( - db.session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, ) - if segment_attachment_bindings: - attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] - index_node_ids.extend(attachment_ids) - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + ).all() - end_at = time.perf_counter() - logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) - except Exception: - # update segment error msg - db.session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "disabled_at": None, - "disabled_by": None, - "enabled": True, - } - ) - db.session.commit() - finally: - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - redis_client.delete(indexing_cache_key) - db.session.close() + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) + except Exception: + # update segment error msg + session.query(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + session.commit() + finally: + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 5fc2597c92..149185f6e2 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,12 +3,12 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from 
core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.datasource_provider_service import DatasourceProviderService @@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return - - data_source_info = document.data_source_info_dict - if document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_page_id" not in data_source_info - or "notion_workspace_id" not in data_source_info - ): - raise ValueError("no notion page found") - workspace_id = data_source_info["notion_workspace_id"] - page_id = data_source_info["notion_page_id"] - page_type = data_source_info["type"] - page_edited_time = data_source_info["last_edited_time"] - credential_id = data_source_info.get("credential_id") - - # Get credentials from datasource provider - datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_datasource_credentials( - tenant_id=document.tenant_id, - credential_id=credential_id, - provider="notion_datasource", - plugin_id="langgenius/notion_datasource", - ) - - if not credential: - logger.error( - "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", - document_id, - document.tenant_id, - credential_id, - ) - document.indexing_status = "error" - document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
- document.stopped_at = naive_utc_now() - db.session.commit() - db.session.close() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - loader = NotionExtractor( - notion_workspace_id=workspace_id, - notion_obj_id=page_id, - notion_page_type=page_type, - notion_access_token=credential.get("integration_secret"), - tenant_id=document.tenant_id, - ) + data_source_info = document.data_source_info_dict + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): + raise ValueError("no notion page found") + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") - last_edited_time = loader.get_notion_last_edited_time() + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_datasource_credentials( + tenant_id=document.tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) - # check the page is updated - if last_edited_time != page_edited_time: - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.commit() - - # delete all document segment and index - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - - end_at = time.perf_counter() - logger.info( - click.style( - "Cleaned document when document update data source or process rule: {} latency: {}".format( - document_id, end_at - start_at - ), - fg="green", - ) + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + document_id, + document.tenant_id, + credential_id, ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + document.indexing_status = "error" + document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
+ document.stopped_at = naive_utc_now() + session.commit() + return - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) - finally: - db.session.close() + loader = NotionExtractor( + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type, + notion_access_token=credential.get("integration_secret"), + tenant_id=document.tenant_id, + ) + + last_edited_time = loader.get_notion_last_edited_time() + + # check the page is updated + if last_edited_time != page_edited_time: + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.commit() + + # delete all document segment and index + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() + + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + + end_at = time.perf_counter() + logger.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + logger.exception("Cleaned document when document update data source or process rule failed") + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index acbdab631b..3bdff60196 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,11 +6,11 @@ import click from celery import shared_task from configs import dify_config +from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): documents = [] start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset is not found: {dataset_id}", 
fg="yellow")) - db.session.close() - return - # check document limit - features = FeatureService.get_features(dataset.tenant_id) - try: - if features.billing.enabled: - vector_space = features.vector_space - count = len(document_ids) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) + return + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) + except Exception as e: + for document_id in document_ids: + document = ( + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) - except Exception as e: + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + return + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) + document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - db.session.close() - return + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - if document: - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() - - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - 
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 161502a228..67a23be952 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,8 +3,9 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start update document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) + return - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.commit() + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.commit() - # delete all document segment and index - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + # delete all document segment and index + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_type = document.doc_form + index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - end_at = time.perf_counter() - logger.info( - click.style( - "Cleaned document when document update data source or process rule: {} latency: {}".format( - document_id, end_at - start_at - ), - fg="green", + # delete from vector index + index_processor.clean(dataset, 
index_node_ids, with_keywords=True, delete_child_chunks=True) + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + end_at = time.perf_counter() + logger.info( + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) ) - ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + except Exception: + logger.exception("Cleaned document when document update data source or process rule failed") - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 4078c8910e..00a963255b 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select from configs import dify_config +from core.db.session_factory import session_factory from core.entities.document_task import DocumentTask from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue( def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]): - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if dataset is None: - logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) - db.session.close() - return - - # check document limit - features = FeatureService.get_features(dataset.tenant_id) + with session_factory.create_session() as session: try: - if features.billing.enabled: - vector_space = features.vector_space - count = len(document_ids) - if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: - raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise
ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - current = int(getattr(vector_space, "size", 0) or 0) - limit = int(getattr(vector_space, "limit", 0) or 0) - if limit > 0 and (current + count) > limit: - raise ValueError( - "Your total number of documents plus the number of uploads have exceeded the limit of " - "your subscription." - ) - except Exception as e: - for document_id in document_ids: - document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset is None: + logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + count = len(document_ids) + if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1: + raise ValueError("Your current plan does not support batch upload, please upgrade your plan.") + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + current = int(getattr(vector_space, "size", 0) or 0) + limit = int(getattr(vector_space, "limit", 0) or 0) + if limit > 0 and (current + count) > limit: + raise ValueError( + "Your total number of documents plus the number of uploads have exceeded the limit of " + "your subscription." + ) + except Exception as e: + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) - if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - return + for document in documents: + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) - if document: + for document in documents: + logger.info(click.style(f"Start process document: {document.id}", fg="green")) + # clean old data index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st # delete from vector index index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() 
document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) - finally: - db.session.close() + indexing_runner = IndexingRunner() + indexing_runner.run(list(documents)) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) @shared_task(queue="dataset") diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 7615469ed0..1f9f21aa7e 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,11 +4,11 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str): logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green")) start_at = time.perf_counter() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() - if not segment: - logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) - db.session.close() - return - - if segment.status != "completed": - logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) - db.session.close() - return - - indexing_cache_key = f"segment_{segment.id}_indexing" - - try: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - - dataset = segment.dataset - - if not dataset: - logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + if not segment: + logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - dataset_document = segment.document - - if not dataset_document: - logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + if segment.status != "completed": + logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status 
!= "completed": - logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) - return + indexing_cache_key = f"segment_{segment.id}_indexing" - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, + try: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + + dataset = segment.dataset + + if not dataset: + logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) + return + + dataset_document = segment.document + + if not dataset_document: + logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) + return + + if ( + not dataset_document.enabled + or dataset_document.archived + or dataset_document.indexing_status != "completed" + ): + logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) + return + + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + multimodel_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodel_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) ) - child_documents.append(child_document) - document.children = child_documents - multimodel_documents = [] - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodel_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - # save vector index - index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) + # save vector index + index_processor.load(dataset, [document], multimodal_documents=multimodel_documents) - end_at = time.perf_counter() - logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("enable segment to index failed") - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.status = "error" - segment.error = str(e) - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + 
end_at = time.perf_counter() + logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception("enable segment to index failed") + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.status = "error" + segment.error = str(e) + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 9f17d09e18..48d3c8e178 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,11 +5,11 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DocumentSegment @@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) - return + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) + return - dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() - if not dataset_document: - logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) - db.session.close() - return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) - db.session.close() - return - # sync index processor - index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if not dataset_document: + logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ) - ).all() - if not segments: - logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) - db.session.close() - return - - try: - documents = [] - multimodal_documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": 
segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": document_id, - "dataset_id": dataset_id, - }, + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, ) + ).all() + if not segments: + logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) + return - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": document_id, - "dataset_id": dataset_id, - }, + try: + documents = [] + multimodal_documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) ) - child_documents.append(child_document) - document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - if dataset.is_multimodal: - for attachment in segment.attachments: - multimodal_documents.append( - AttachmentDocument( - page_content=attachment["name"], - metadata={ - "doc_id": attachment["id"], - "doc_hash": "", - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - "doc_type": DocType.IMAGE, - }, - ) - ) - documents.append(document) - # save vector index - index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) - - end_at = time.perf_counter() - logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception("enable segments to index failed") - # update segment error msg - db.session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "error": str(e), - "status": "error", - "disabled_at": naive_utc_now(), - "enabled": False, - } - ) - db.session.commit() - finally: - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) + except Exception as e: + 
logger.exception("enable segments to index failed") + # update segment error msg + session.query(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": naive_utc_now(), + "enabled": False, + } + ) + session.commit() + finally: + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 1b2a653c01..af72023da1 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -4,8 +4,8 @@ import time import click from celery import shared_task +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from extensions.ext_database import db from models.dataset import Document logger = logging.getLogger(__name__) @@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): logger.info(click.style(f"Recover document: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) + return - try: - indexing_runner = IndexingRunner() - if document.indexing_status in {"waiting", "parsing", "cleaning"}: - indexing_runner.run([document]) - elif document.indexing_status == "splitting": - indexing_runner.run_in_splitting_status(document) - elif document.indexing_status == "indexing": - indexing_runner.run_in_indexing_status(document) - end_at = time.perf_counter() - logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("recover_document_indexing_task failed, document_id: %s", document_id) - finally: - db.session.close() + try: + indexing_runner = IndexingRunner() + if document.indexing_status in {"waiting", "parsing", "cleaning"}: + indexing_runner.run([document]) + elif document.indexing_status == "splitting": + indexing_runner.run_in_splitting_status(document) + elif document.indexing_status == "indexing": + indexing_runner.run_in_indexing_status(document) + end_at = time.perf_counter() + logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("recover_document_indexing_task failed, document_id: %s", document_id) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 3227f6da96..817249845a 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -1,15 +1,20 @@ import logging import time from collections.abc import Callable +from typing import Any, cast import click import 
sqlalchemy as sa from celery import shared_task from sqlalchemy import delete +from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker +from configs import dify_config +from core.db.session_factory import session_factory from extensions.ext_database import db +from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from models import ( ApiToken, AppAnnotationHitHistory, @@ -40,6 +45,7 @@ from models.workflow import ( ConversationVariable, Workflow, WorkflowAppLog, + WorkflowArchiveLog, ) from repositories.factory import DifyAPIRepositoryFactory @@ -64,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_app_workflow_runs(tenant_id, app_id) _delete_app_workflow_node_executions(tenant_id, app_id) _delete_app_workflow_app_logs(tenant_id, app_id) + if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED: + _delete_app_workflow_archive_logs(tenant_id, app_id) + _delete_archived_workflow_run_files(tenant_id, app_id) _delete_app_conversations(tenant_id, app_id) _delete_app_messages(tenant_id, app_id) _delete_workflow_tool_providers(tenant_id, app_id) @@ -77,7 +86,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_workflow_webhook_triggers(tenant_id, app_id) _delete_workflow_schedule_plans(tenant_id, app_id) _delete_workflow_trigger_logs(tenant_id, app_id) - end_at = time.perf_counter() logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: @@ -89,8 +97,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): - def del_model_config(model_config_id: str): - db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + def del_model_config(session, model_config_id: str): + session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -101,8 +109,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): - def del_site(site_id: str): - db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + def del_site(session, site_id: str): + session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -113,8 +121,8 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): - def del_mcp_server(mcp_server_id: str): - db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + def del_mcp_server(session, mcp_server_id: str): + session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -125,8 +133,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): - def del_api_token(api_token_id: str): - db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + def del_api_token(session, api_token_id: str): + session.query(ApiToken).where(ApiToken.id == 
api_token_id).delete(synchronize_session=False) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -137,8 +145,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): - def del_installed_app(installed_app_id: str): - db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + def del_installed_app(session, installed_app_id: str): + session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -149,10 +157,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): - def del_recommended_app(recommended_app_id: str): - db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False - ) + def del_recommended_app(session, recommended_app_id: str): + session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -163,8 +169,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): - def del_annotation_hit_history(annotation_hit_history_id: str): - db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( + def del_annotation_hit_history(session, annotation_hit_history_id: str): + session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( synchronize_session=False ) @@ -175,8 +181,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): "annotation hit history", ) - def del_annotation_setting(annotation_setting_id: str): - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( + def del_annotation_setting(session, annotation_setting_id: str): + session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( synchronize_session=False ) @@ -189,8 +195,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): - def del_dataset_join(dataset_join_id: str): - db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + def del_dataset_join(session, dataset_join_id: str): + session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -201,8 +207,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): - def del_workflow(workflow_id: str): - db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + def del_workflow(session, workflow_id: str): + session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -241,10 +247,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): - def del_workflow_app_log(workflow_app_log_id: 
str): - db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete( - synchronize_session=False - ) + def del_workflow_app_log(session, workflow_app_log_id: str): + session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -254,12 +258,51 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): ) -def _delete_app_conversations(tenant_id: str, app_id: str): - def del_conversation(conversation_id: str): - db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( +def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): + def del_workflow_archive_log(workflow_archive_log_id: str): + db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) - db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) + + _delete_records( + """select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", + {"tenant_id": tenant_id, "app_id": app_id}, + del_workflow_archive_log, + "workflow archive log", + ) + + +def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): + prefix = f"{tenant_id}/app_id={app_id}/" + try: + archive_storage = get_archive_storage() + except ArchiveStorageNotConfiguredError as e: + logger.info("Archive storage not configured, skipping archive file cleanup: %s", e) + return + + try: + keys = archive_storage.list_objects(prefix) + except Exception: + logger.exception("Failed to list archive files for app %s", app_id) + return + + deleted = 0 + for key in keys: + try: + archive_storage.delete_object(key) + deleted += 1 + except Exception: + logger.exception("Failed to delete archive object %s", key) + + logger.info("Deleted %s archive objects for app %s", deleted, app_id) + + +def _delete_app_conversations(tenant_id: str, app_id: str): + def del_conversation(session, conversation_id: str): + session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -270,28 +313,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str): def _delete_conversation_variables(*, app_id: str): - stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) - with db.engine.connect() as conn: - conn.execute(stmt) - conn.commit() + with session_factory.create_session() as session: + stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) + session.execute(stmt) + session.commit() logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) def _delete_app_messages(tenant_id: str, app_id: str): - def del_message(message_id: str): - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete( + def del_message(session, message_id: str): + session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) + session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( synchronize_session=False ) - 
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( + session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) + session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( synchronize_session=False ) - db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - db.session.query(Message).where(Message.id == message_id).delete() + session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) + session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) + session.query(Message).where(Message.id == message_id).delete() _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -302,8 +343,8 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): - def del_tool_provider(tool_provider_id: str): - db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( + def del_tool_provider(session, tool_provider_id: str): + session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( synchronize_session=False ) @@ -316,8 +357,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): - def del_tag_binding(tag_binding_id: str): - db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + def del_tag_binding(session, tag_binding_id: str): + session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -328,8 +369,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): - def del_end_user(end_user_id: str): - db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + def del_end_user(session, end_user_id: str): + session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -340,10 +381,8 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): - def del_trace_app_config(trace_app_config_id: str): - db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False - ) + def del_trace_app_config(session, trace_app_config_id: str): + session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -381,14 +420,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with db.engine.begin() as conn: + with session_factory.create_session() as 
session: # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables WHERE app_id = :app_id LIMIT :batch_size """ - result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) + result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) rows = list(result) if not rows: @@ -399,7 +438,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: # Clean up associated Offload data first if file_ids: - files_deleted = _delete_draft_variable_offload_data(conn, file_ids) + files_deleted = _delete_draft_variable_offload_data(session, file_ids) total_files_deleted += files_deleted # Delete the draft variables @@ -407,8 +446,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: DELETE FROM workflow_draft_variables WHERE id IN :ids """ - deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}) - batch_deleted = deleted_result.rowcount + deleted_result = cast( + CursorResult[Any], + session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}), + ) + batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0) total_deleted += batch_deleted logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green")) @@ -423,7 +465,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: return total_deleted -def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: +def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: """ Delete Offload data associated with WorkflowDraftVariable file_ids. @@ -434,7 +476,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: 4. 
Deletes WorkflowDraftVariableFile records Args: - conn: Database connection + session: Database connection file_ids: List of WorkflowDraftVariableFile IDs Returns: @@ -450,12 +492,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: try: # Get WorkflowDraftVariableFile records and their associated UploadFile keys query_sql = """ - SELECT wdvf.id, uf.key, uf.id as upload_file_id - FROM workflow_draft_variable_files wdvf - JOIN upload_files uf ON wdvf.upload_file_id = uf.id - WHERE wdvf.id IN :file_ids - """ - result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)}) + SELECT wdvf.id, uf.key, uf.id as upload_file_id + FROM workflow_draft_variable_files wdvf + JOIN upload_files uf ON wdvf.upload_file_id = uf.id + WHERE wdvf.id IN :file_ids \ + """ + result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)}) file_records = list(result) # Delete from object storage and collect upload file IDs @@ -473,17 +515,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: # Delete UploadFile records if upload_file_ids: delete_upload_files_sql = """ - DELETE FROM upload_files - WHERE id IN :upload_file_ids - """ - conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)}) + DELETE \ + FROM upload_files + WHERE id IN :upload_file_ids \ + """ + session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)}) # Delete WorkflowDraftVariableFile records delete_variable_files_sql = """ - DELETE FROM workflow_draft_variable_files - WHERE id IN :file_ids - """ - conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)}) + DELETE \ + FROM workflow_draft_variable_files + WHERE id IN :file_ids \ + """ + session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)}) except Exception: logging.exception("Error deleting draft variable offload data:") @@ -493,8 +537,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): - def del_app_trigger(trigger_id: str): - db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + def del_app_trigger(session, trigger_id: str): + session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -505,8 +549,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): - def del_plugin_trigger(trigger_id: str): - db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( + def del_plugin_trigger(session, trigger_id: str): + session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( synchronize_session=False ) @@ -519,8 +563,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): - def del_webhook_trigger(trigger_id: str): - db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( + def del_webhook_trigger(session, trigger_id: str): + session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( synchronize_session=False ) @@ -533,10 +577,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, 
app_id: str): - def del_schedule_plan(plan_id: str): - db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete( - synchronize_session=False - ) + def del_schedule_plan(session, plan_id: str): + session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) _delete_records( """select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -547,8 +589,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): - def del_trigger_log(log_id: str): - db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + def del_trigger_log(session, log_id: str): + session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -560,18 +602,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query_sql), params) - if rs.rowcount == 0: + with session_factory.create_session() as session: + rs = session.execute(sa.text(query_sql), params) + rows = rs.fetchall() + if not rows: break - for i in rs: + for i in rows: record_id = str(i.id) try: - delete_func(record_id) - db.session.commit() + delete_func(session, record_id) logger.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: logger.exception("Error occurred while deleting %s %s", name, record_id) - continue + # continue with next record even if one deletion fails + session.rollback() + break + session.commit() + rs.close() diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index c0ab2d0b41..c3c255fb17 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -5,8 +5,8 @@ import click from celery import shared_task from sqlalchemy import select +from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Document, DocumentSegment @@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str): logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).where(Document.id == document_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="red")) - db.session.close() - return + with session_factory.create_session() as session: + document = session.query(Document).where(Document.id == document_id).first() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="red")) + return - if document.indexing_status != "completed": - logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red")) - db.session.close() - return + if document.indexing_status != "completed": + logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red")) + return - 
indexing_cache_key = f"document_{document.id}_indexing" + indexing_cache_key = f"document_{document.id}_indexing" - try: - dataset = document.dataset + try: + dataset = document.dataset - if not dataset: - raise Exception("Document has no dataset") + if not dataset: + raise Exception("Document has no dataset") - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() - index_node_ids = [segment.index_node_id for segment in segments] - if index_node_ids: - try: - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) - except Exception: - logger.exception("clean dataset %s from index failed", dataset.id) - # update segment to disable - db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( - { - DocumentSegment.enabled: False, - DocumentSegment.disabled_at: naive_utc_now(), - DocumentSegment.disabled_by: document.disabled_by, - DocumentSegment.updated_at: naive_utc_now(), - } - ) - db.session.commit() + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + if index_node_ids: + try: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + except Exception: + logger.exception("clean dataset %s from index failed", dataset.id) + # update segment to disable + session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( + { + DocumentSegment.enabled: False, + DocumentSegment.disabled_at: naive_utc_now(), + DocumentSegment.disabled_by: document.disabled_by, + DocumentSegment.updated_at: naive_utc_now(), + } + ) + session.commit() - end_at = time.perf_counter() - logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green")) - except Exception: - logger.exception("remove document from index failed") - if not document.archived: - document.enabled = True - db.session.commit() - finally: - redis_client.delete(indexing_cache_key) - db.session.close() + end_at = time.perf_counter() + logger.info( + click.style( + f"Document removed from index: {document.id} latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("remove document from index failed") + if not document.archived: + document.enabled = True + session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 9d208647e6..f20b15ac83 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account, Tenant @@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ Usage: retry_document_indexing_task.delay(dataset_id, 
document_ids, user_id) """ start_at = time.perf_counter() - try: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) - return - user = db.session.query(Account).where(Account.id == user_id).first() - if not user: - logger.info(click.style(f"User not found: {user_id}", fg="red")) - return - tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() - if not tenant: - raise ValueError("Tenant not found") - user.current_tenant = tenant + with session_factory.create_session() as session: + try: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + user = session.query(Account).where(Account.id == user_id).first() + if not user: + logger.info(click.style(f"User not found: {user_id}", fg="red")) + return + tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() + if not tenant: + raise ValueError("Tenant not found") + user.current_tenant = tenant - for document_id in document_ids: - retry_indexing_cache_key = f"document_{document_id}_is_retried" - # check document limit - features = FeatureService.get_features(tenant.id) - try: - if features.billing.enabled: - vector_space = features.vector_space - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." - ) - except Exception as e: + for document_id in document_ids: + retry_indexing_cache_key = f"document_{document_id}_is_retried" + # check document limit + features = FeatureService.get_features(tenant.id) + try: + if features.billing.enabled: + vector_space = features.vector_space + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
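For context on the central pattern in these hunks: every task here drops the request-scoped `db.session` (and ad-hoc `sessionmaker(bind=db.engine)` calls) in favour of `with session_factory.create_session() as session:`, and the per-record delete helpers now receive that session explicitly. The factory itself is not part of this diff; a minimal sketch of the contract the call sites assume (illustrative names, not the actual `core.db.session_factory` implementation) looks like this:

```python
# Hypothetical sketch of the contract assumed by session_factory.create_session()
# in the hunks above; the real core.db.session_factory module is not in this diff.
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker


class SessionFactory:
    def __init__(self, database_uri: str) -> None:
        self._engine = create_engine(database_uri)
        self._sessionmaker = sessionmaker(bind=self._engine, expire_on_commit=False)

    @contextmanager
    def create_session(self) -> Iterator[Session]:
        # Yield a short-lived session; callers commit explicitly and the
        # session is always closed on exit, as in the refactored tasks.
        session = self._sessionmaker()
        try:
            yield session
        finally:
            session.close()


# Illustrative URI; the real factory would be bound to the configured database.
session_factory = SessionFactory("sqlite:///example.db")
```

With that shape, commits stay explicit at each call site while session lifetime is scoped to the `with` block rather than to the Flask request.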
+ ) + except Exception as e: + document = ( + session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + redis_client.delete(retry_indexing_cache_key) + return + + logger.info(click.style(f"Start retry document: {document_id}", fg="green")) document = ( - db.session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) - if document: + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) + return + try: + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.add(document) + session.commit() + + if dataset.runtime_mode == "rag_pipeline": + rag_pipeline_service = RagPipelineService() + rag_pipeline_service.retry_error_document(dataset, document, user) + else: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) + except Exception as ex: document.indexing_status = "error" - document.error = str(e) + document.error = str(ex) document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - redis_client.delete(retry_indexing_cache_key) - return - - logger.info(click.style(f"Start retry document: {document_id}", fg="green")) - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + session.add(document) + session.commit() + logger.info(click.style(str(ex), fg="yellow")) + redis_client.delete(retry_indexing_cache_key) + logger.exception("retry_document_indexing_task failed, document_id: %s", document_id) + end_at = time.perf_counter() + logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logger.exception( + "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids ) - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) - return - try: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - 
db.session.add(document) - db.session.commit() - - if dataset.runtime_mode == "rag_pipeline": - rag_pipeline_service = RagPipelineService() - rag_pipeline_service.retry_error_document(dataset, document, user) - else: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) - except Exception as ex: - document.indexing_status = "error" - document.error = str(ex) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - logger.info(click.style(str(ex), fg="yellow")) - redis_client.delete(retry_indexing_cache_key) - logger.exception("retry_document_indexing_task failed, document_id: %s", document_id) - end_at = time.perf_counter() - logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) - except Exception as e: - logger.exception( - "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids - ) - raise e - finally: - db.session.close() + raise e diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 0dc1d841f4..f1c8c56995 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,11 +3,11 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import delete, select +from core.db.session_factory import session_factory from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ start_at = time.perf_counter() - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if dataset is None: - raise ValueError("Dataset not found") + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") - sync_indexing_cache_key = f"document_{document_id}_is_sync" - # check document limit - features = FeatureService.get_features(dataset.tenant_id) - try: - if features.billing.enabled: - vector_space = features.vector_space - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." - ) - except Exception as e: - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - if document: + sync_indexing_cache_key = f"document_{document_id}_is_sync" + # check document limit + features = FeatureService.get_features(dataset.tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
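Both indexing tasks in this diff also replace the per-row `db.session.delete(segment)` loop with a single bulk `DELETE` built from the segment IDs. A small sketch of the before/after, using the `DocumentSegment` model these tasks already import (the session argument is illustrative):

```python
# Sketch of the segment-cleanup change in the retry/sync indexing tasks:
# one bulk DELETE ... WHERE id IN (...) instead of one DELETE per ORM object.
from collections.abc import Sequence

from sqlalchemy import delete
from sqlalchemy.orm import Session

from models.dataset import DocumentSegment


def delete_segments_per_row(session: Session, segments: Sequence[DocumentSegment]) -> None:
    # Old pattern: each instance is deleted through the unit of work.
    for segment in segments:
        session.delete(segment)
    session.commit()


def delete_segments_bulk(session: Session, segments: Sequence[DocumentSegment]) -> None:
    # New pattern: a single statement, no per-object identity-map bookkeeping.
    segment_ids = [segment.id for segment in segments]
    if segment_ids:
        session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
    session.commit()
```

The bulk form issues one round trip regardless of how many segments a document has.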
+ ) + except Exception as e: + document = ( + session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + session.add(document) + session.commit() + redis_client.delete(sync_indexing_cache_key) + return + + logger.info(click.style(f"Start sync website document: {document_id}", fg="green")) + document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) + return + try: + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + segment_ids = [segment.id for segment in segments] + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + session.execute(segment_delete_stmt) + session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() + session.add(document) + session.commit() + + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) + except Exception as ex: document.indexing_status = "error" - document.error = str(e) + document.error = str(ex) document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - redis_client.delete(sync_indexing_cache_key) - return - - logger.info(click.style(f"Start sync website document: {document_id}", fg="green")) - document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - if not document: - logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) - return - try: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - - segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - db.session.add(document) - db.session.commit() - - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) - except Exception as ex: - document.indexing_status = "error" - document.error = str(ex) - document.stopped_at = naive_utc_now() - db.session.add(document) - db.session.commit() - logger.info(click.style(str(ex), fg="yellow")) - redis_client.delete(sync_indexing_cache_key) - logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id) - end_at = time.perf_counter() - logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green")) + session.add(document) + session.commit() + logger.info(click.style(str(ex), fg="yellow")) + redis_client.delete(sync_indexing_cache_key) + 
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id) + end_at = time.perf_counter() + logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index ee1d31aa91..d18ea2c23c 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -16,6 +16,7 @@ from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom +from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerInvokeEventResponse from core.plugin.impl.exc import PluginInvokeError @@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager from core.workflow.enums import NodeType, WorkflowExecutionStatus from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited -from extensions.ext_database import db from models.enums import ( AppTriggerType, CreatorUserRole, @@ -257,7 +257,7 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity - with Session(db.engine) as session: + with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index ed92f3f3c5..7698a1a6b8 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -7,9 +7,9 @@ from celery import shared_task from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.utils.locks import build_trigger_refresh_lock_key -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService @@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None: logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id) try: now: int = _now_ts() - with Session(db.engine) as session: + with session_factory.create_session() as session: subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id) if not subscription: diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 7d145fb50c..3b3c6e5313 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,11 +10,10 @@ import logging from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from core.workflow.entities.workflow_execution import WorkflowExecution from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom @@ -46,10 +45,7 @@ def 
save_workflow_execution_task( True if successful, False otherwise """ try: - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowExecution.model_validate(execution_data) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 8f5127670f..b30a4ff15b 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -10,13 +10,12 @@ import logging from celery import shared_task from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -48,10 +47,7 @@ def save_workflow_node_execution_task( True if successful, False otherwise """ try: - # Create a new session for this task - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowNodeExecution.model_validate(execution_data) diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index f54e02a219..8c64d3ab27 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -1,15 +1,14 @@ import logging from celery import shared_task -from sqlalchemy.orm import sessionmaker +from core.db.session_factory import session_factory from core.workflow.nodes.trigger_schedule.exc import ( ScheduleExecutionError, ScheduleNotFoundError, TenantOwnerNotFoundError, ) from enums.quota_type import QuotaType, unlimited -from extensions.ext_database import db from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError @@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ - - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - with session_factory() as session: + with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule: raise ScheduleNotFoundError(f"Schedule {schedule_id} not found") diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 7cdc3cb205..f46d1bf5db 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,8 +4,8 @@ from unittest.mock import patch import pytest from sqlalchemy import delete +from core.db.session_factory import session_factory from core.variables.segments import StringSegment -from extensions.ext_database import db from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import 
_delete_draft_variables, dele @pytest.fixture def app_and_tenant(flask_req_ctx): tenant_id = uuid.uuid4() - tenant = Tenant( - id=tenant_id, - name="test_tenant", - ) - db.session.add(tenant) + with session_factory.create_session() as session: + tenant = Tenant(name="test_tenant") + session.add(tenant) + session.flush() - app = App( - tenant_id=tenant_id, # Now tenant.id will have a value - name=f"Test App for tenant {tenant.id}", - mode="workflow", - enable_site=True, - enable_api=True, - ) - db.session.add(app) - db.session.flush() - yield (tenant, app) + app = App( + tenant_id=tenant.id, + name=f"Test App for tenant {tenant.id}", + mode="workflow", + enable_site=True, + enable_api=True, + ) + session.add(app) + session.flush() - # Cleanup with proper error handling - db.session.delete(app) - db.session.delete(tenant) + # return detached objects (ids will be used by tests) + return (tenant, app) class TestDeleteDraftVariablesIntegration: @@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration: """Create test data with apps and draft variables.""" tenant, app = app_and_tenant - # Create a second app for testing - app2 = App( - tenant_id=tenant.id, - name="Test App 2", - mode="workflow", - enable_site=True, - enable_api=True, - ) - db.session.add(app2) - db.session.commit() - - # Create draft variables for both apps - variables_app1 = [] - variables_app2 = [] - - for i in range(5): - var1 = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), + with session_factory.create_session() as session: + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, ) - db.session.add(var1) - variables_app1.append(var1) + session.add(app2) + session.flush() - var2 = WorkflowDraftVariable.new_node_variable( - app_id=app2.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), - ) - db.session.add(var2) - variables_app2.append(var2) + variables_app1 = [] + variables_app2 = [] + for i in range(5): + var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var1) + variables_app1.append(var1) - # Commit all the variables to the database - db.session.commit() + var2 = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var2) + variables_app2.append(var2) + session.commit() + + app2_id = app2.id yield { "app1": app, - "app2": app2, + "app2": App(id=app2_id), # dummy with id to avoid open session "tenant": tenant, "variables_app1": variables_app1, "variables_app2": variables_app2, } - # Cleanup - refresh session and check if objects still exist - db.session.rollback() # Clear any pending changes - - # Clean up remaining variables - cleanup_query = ( - delete(WorkflowDraftVariable) - .where( - WorkflowDraftVariable.app_id.in_([app.id, app2.id]), + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.app_id.in_([app.id, app2_id])) + .execution_options(synchronize_session=False) ) - .execution_options(synchronize_session=False) - ) - db.session.execute(cleanup_query) - - # Clean up 
app2 - app2_obj = db.session.get(App, app2.id) - if app2_obj: - db.session.delete(app2_obj) - - db.session.commit() + session.execute(cleanup_query) + app2_obj = session.get(App, app2_id) + if app2_obj: + session.delete(app2_obj) + session.commit() def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data): - """Test that batch deletion only removes variables for the specified app.""" data = setup_test_data app1_id = data["app1"].id app2_id = data["app2"].id - # Verify initial state - app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + with session_factory.create_session() as session: + app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() assert app1_vars_before == 5 assert app2_vars_before == 5 - # Delete app1 variables deleted_count = delete_draft_variables_batch(app1_id, batch_size=10) - - # Verify results assert deleted_count == 5 - app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() - - assert app1_vars_after == 0 # All app1 variables deleted - assert app2_vars_after == 5 # App2 variables unchanged + with session_factory.create_session() as session: + app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + assert app1_vars_after == 0 + assert app2_vars_after == 5 def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data): - """Test batch deletion with small batch size processes all records.""" data = setup_test_data app1_id = data["app1"].id - # Use small batch size to force multiple batches deleted_count = delete_draft_variables_batch(app1_id, batch_size=2) - assert deleted_count == 5 - # Verify all variables are deleted - remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert remaining_vars == 0 def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): - """Test that deleting variables for nonexistent app returns 0.""" - nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format - + nonexistent_app_id = str(uuid.uuid4()) deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100) - assert deleted_count == 0 def test_delete_draft_variables_wrapper_function(self, setup_test_data): - """Test that _delete_draft_variables wrapper function works correctly.""" data = setup_test_data app1_id = data["app1"].id - # Verify initial state - vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert vars_before == 5 - # Call wrapper function deleted_count = _delete_draft_variables(app1_id) - - # Verify results assert deleted_count == 5 - vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + with session_factory.create_session() as session: + vars_after = 
session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() assert vars_after == 0 def test_batch_deletion_handles_large_dataset(self, app_and_tenant): - """Test batch deletion with larger dataset to verify batching logic.""" tenant, app = app_and_tenant - - # Create many draft variables - variables = [] - for i in range(25): - var = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="test_value"), - node_execution_id=str(uuid.uuid4()), - ) - db.session.add(var) - variables.append(var) - variable_ids = [i.id for i in variables] - - # Commit the variables to the database - db.session.commit() + variable_ids: list[str] = [] + with session_factory.create_session() as session: + variables = [] + for i in range(25): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] try: - # Use small batch size to force multiple batches deleted_count = delete_draft_variables_batch(app.id, batch_size=8) - assert deleted_count == 25 - - # Verify all variables are deleted - remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() - assert remaining_vars == 0 - + with session_factory.create_session() as session: + remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + assert remaining == 0 finally: - query = ( - delete(WorkflowDraftVariable) - .where( - WorkflowDraftVariable.id.in_(variable_ids), + with session_factory.create_session() as session: + query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) ) - .execution_options(synchronize_session=False) - ) - db.session.execute(query) + session.execute(query) + session.commit() class TestDeleteDraftVariablesWithOffloadIntegration: - """Integration tests for draft variable deletion with Offload data.""" - @pytest.fixture def setup_offload_test_data(self, app_and_tenant): - """Create test data with draft variables that have associated Offload files.""" tenant, app = app_and_tenant - - # Create UploadFile records + from core.variables.types import SegmentType from libs.datetime_utils import naive_utc_now - upload_file1 = UploadFile( - tenant_id=tenant.id, - storage_type="local", - key="test/file1.json", - name="file1.json", - size=1024, - extension="json", - mime_type="application/json", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=str(uuid.uuid4()), - created_at=naive_utc_now(), - used=False, - ) - upload_file2 = UploadFile( - tenant_id=tenant.id, - storage_type="local", - key="test/file2.json", - name="file2.json", - size=2048, - extension="json", - mime_type="application/json", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=str(uuid.uuid4()), - created_at=naive_utc_now(), - used=False, - ) - db.session.add(upload_file1) - db.session.add(upload_file2) - db.session.flush() + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + 
tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + session.flush() - # Create WorkflowDraftVariableFile records - from core.variables.types import SegmentType + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() - var_file1 = WorkflowDraftVariableFile( - tenant_id=tenant.id, - app_id=app.id, - user_id=str(uuid.uuid4()), - upload_file_id=upload_file1.id, - size=1024, - length=10, - value_type=SegmentType.STRING, - ) - var_file2 = WorkflowDraftVariableFile( - tenant_id=tenant.id, - app_id=app.id, - user_id=str(uuid.uuid4()), - upload_file_id=upload_file2.id, - size=2048, - length=20, - value_type=SegmentType.OBJECT, - ) - db.session.add(var_file1) - db.session.add(var_file2) - db.session.flush() + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() - # Create WorkflowDraftVariable records with file associations - draft_var1 = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id="node_1", - name="large_var_1", - value=StringSegment(value="truncated..."), - node_execution_id=str(uuid.uuid4()), - file_id=var_file1.id, - ) - draft_var2 = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id="node_2", - name="large_var_2", - value=StringSegment(value="truncated..."), - node_execution_id=str(uuid.uuid4()), - file_id=var_file2.id, - ) - # Create a regular variable without Offload data - draft_var3 = WorkflowDraftVariable.new_node_variable( - app_id=app.id, - node_id="node_3", - name="regular_var", - value=StringSegment(value="regular_value"), - node_execution_id=str(uuid.uuid4()), - ) + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } - db.session.add(draft_var1) - db.session.add(draft_var2) - db.session.add(draft_var3) - db.session.commit() + yield data - yield { - "app": app, - "tenant": tenant, - "upload_files": [upload_file1, upload_file2], - "variable_files": [var_file1, var_file2], - "draft_variables": [draft_var1, draft_var2, draft_var3], - } - - # Cleanup - db.session.rollback() - - # Clean up any remaining records - for table, 
ids in [ - (WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]), - (WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]), - (UploadFile, [uf.id for uf in [upload_file1, upload_file2]]), - ]: - cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) - db.session.execute(cleanup_query) - - db.session.commit() + with session_factory.create_session() as session: + session.rollback() + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() @patch("extensions.ext_storage.storage") def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): - """Test that deleting draft variables also cleans up associated Offload data.""" data = setup_offload_test_data app_id = data["app"].id - - # Mock storage deletion to succeed mock_storage.delete.return_value = None - # Verify initial state - draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = db.session.query(WorkflowDraftVariableFile).count() - upload_files_before = db.session.query(UploadFile).count() - - assert draft_vars_before == 3 # 2 with files + 1 regular + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = session.query(WorkflowDraftVariableFile).count() + upload_files_before = session.query(UploadFile).count() + assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 - # Delete draft variables deleted_count = delete_draft_variables_batch(app_id, batch_size=10) - - # Verify results assert deleted_count == 3 - # Check that all draft variables are deleted - draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() assert draft_vars_after == 0 - # Check that associated Offload data is cleaned up - var_files_after = db.session.query(WorkflowDraftVariableFile).count() - upload_files_after = db.session.query(UploadFile).count() + with session_factory.create_session() as session: + var_files_after = session.query(WorkflowDraftVariableFile).count() + upload_files_after = session.query(UploadFile).count() + assert var_files_after == 0 + assert upload_files_after == 0 - assert var_files_after == 0 # All variable files should be deleted - assert upload_files_after == 0 # All upload files should be deleted - - # Verify storage deletion was called for both files assert mock_storage.delete.call_count == 2 storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list] assert "test/file1.json" in storage_keys_deleted @@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @patch("extensions.ext_storage.storage") def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data): - """Test that database cleanup continues even when storage deletion fails.""" data = setup_offload_test_data app_id = data["app"].id - - # Mock storage deletion to fail for first file, succeed 
for second mock_storage.delete.side_effect = [Exception("Storage error"), None] - # Delete draft variables deleted_count = delete_draft_variables_batch(app_id, batch_size=10) - - # Verify that all draft variables are still deleted assert deleted_count == 3 - draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() assert draft_vars_after == 0 - # Database cleanup should still succeed even with storage errors - var_files_after = db.session.query(WorkflowDraftVariableFile).count() - upload_files_after = db.session.query(UploadFile).count() - + with session_factory.create_session() as session: + var_files_after = session.query(WorkflowDraftVariableFile).count() + upload_files_after = session.query(UploadFile).count() assert var_files_after == 0 assert upload_files_after == 0 - # Verify storage deletion was attempted for both files assert mock_storage.delete.call_count == 2 @patch("extensions.ext_storage.storage") def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data): - """Test deletion with mix of variables with and without Offload data.""" data = setup_offload_test_data app_id = data["app"].id - - # Create additional app with only regular variables (no offload data) tenant = data["tenant"] - app2 = App( - tenant_id=tenant.id, - name="Test App 2", - mode="workflow", - enable_site=True, - enable_api=True, - ) - db.session.add(app2) - db.session.flush() - # Add regular variables to app2 - regular_vars = [] - for i in range(3): - var = WorkflowDraftVariable.new_node_variable( - app_id=app2.id, - node_id=f"node_{i}", - name=f"var_{i}", - value=StringSegment(value="regular_value"), - node_execution_id=str(uuid.uuid4()), + with session_factory.create_session() as session: + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, ) - db.session.add(var) - regular_vars.append(var) - db.session.commit() + session.add(app2) + session.flush() + + for i in range(3): + var = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + session.commit() try: - # Mock storage deletion mock_storage.delete.return_value = None - - # Delete variables for app2 (no offload data) deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10) assert deleted_count_app2 == 3 - - # Verify storage wasn't called for app2 (no offload files) mock_storage.delete.assert_not_called() - # Delete variables for original app (with offload data) deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10) assert deleted_count_app1 == 3 - - # Now storage should be called for the offload files assert mock_storage.delete.call_count == 2 - finally: - # Cleanup app2 and its variables - cleanup_vars_query = ( - delete(WorkflowDraftVariable) - .where(WorkflowDraftVariable.app_id == app2.id) - .execution_options(synchronize_session=False) - ) - db.session.execute(cleanup_vars_query) - - app2_obj = db.session.get(App, app2.id) - if app2_obj: - db.session.delete(app2_obj) - db.session.commit() + with session_factory.create_session() as session: + cleanup_vars_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.app_id == app2.id) + .execution_options(synchronize_session=False) 
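The integration-test fixtures above follow the same discipline: rows are created and flushed in one short-lived session, the test works with plain IDs, and cleanup runs in a fresh session. A compact sketch of that fixture shape (illustrative fixture name and field values; `App` and `session_factory` are the imports already used in this test module):

```python
# Sketch of the fixture pattern used by the refactored integration tests:
# seed in one session, yield an ID, clean up in another session.
import uuid

import pytest

from core.db.session_factory import session_factory
from models.model import App


@pytest.fixture
def seeded_app_id():
    with session_factory.create_session() as session:
        app = App(
            tenant_id=str(uuid.uuid4()),  # illustrative; the real fixtures use a Tenant row
            name="fixture app",
            mode="workflow",
            enable_site=True,
            enable_api=True,
        )
        session.add(app)
        session.flush()  # populate app.id before the session closes
        app_id = app.id
        session.commit()

    yield app_id

    with session_factory.create_session() as session:
        app_obj = session.get(App, app_id)
        if app_obj is not None:
            session.delete(app_obj)
        session.commit()
```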
+ ) + session.execute(cleanup_vars_query) + app2_obj = session.get(App, app2.id) + if app2_obj: + session.delete(app2_obj) + session.commit() diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 9b0bd6275b..1a9d69b2d2 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -5,13 +5,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import NodeRunResult from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index d814da8ec7..1bcac3b5fe 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,11 +5,11 @@ from urllib.parse import urlencode import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.http_request.node import HttpRequestNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d268c5da22..c361bfcc6f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -5,13 +5,13 @@ from collections.abc import Generator from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 654db59bec..7445699a86 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -4,11 +4,11 @@ import uuid from unittest.mock import MagicMock from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 3bcb9a3a34..bc03ce1b96 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -4,10 +4,10 @@ import uuid import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index d666f0ebe2..cfbef52c93 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -3,12 +3,12 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.tool.tool_node import ToolNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 5555400ca6..4f5190e533 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -220,6 +220,23 @@ class TestAnnotationService: # Note: In this test, no annotation setting exists, so task should not be called mock_external_service_dependencies["add_task"].delay.assert_not_called() + def test_insert_app_annotation_directly_requires_question( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Question must be provided when inserting annotations directly. 
+ """ + fake = Faker() + app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + annotation_args = { + "question": None, + "answer": fake.text(max_nb_chars=200), + } + + with pytest.raises(ValueError): + AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + def test_insert_app_annotation_directly_app_not_found( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 40380b09d2..bd2fd14ffa 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -4,7 +4,13 @@ import pytest from faker import Faker from enums.cloud_plan import CloudPlan -from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel +from services.feature_service import ( + FeatureModel, + FeatureService, + KnowledgeRateLimitModel, + LicenseStatus, + SystemFeatureModel, +) class TestFeatureService: @@ -274,7 +280,7 @@ class TestFeatureService: mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -324,6 +330,61 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval for an unauthenticated user. + + This test verifies that: + - The response payload is minimized (e.g., verbose license details are excluded). + - Essential UI configuration (Branding, SSO, Marketplace) remains available. + - The response structure adheres to the public schema for unauthenticated clients. + """ + # Arrange: Setup test data with exact same config as success test + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute with is_authenticated=False + result = FeatureService.get_system_features(is_authenticated=False) + + # Assert: Basic structure + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # --- 1. Verify Response Payload Optimization (Data Minimization) --- + # Ensure only essential UI flags are returned to unauthenticated clients + # to keep the payload lightweight and adhere to architectural boundaries. + assert result.license.status == LicenseStatus.NONE + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.limit == 0 + assert result.license.workspaces.size == 0 + + # --- 2. Verify Public UI Configuration Availability --- + # Ensure that data required for frontend rendering remains accessible. 
+ + # Branding should match the mock data + assert result.branding.enabled is True + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + + # SSO settings should be visible for login page rendering + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # General auth settings should be visible + assert result.enable_email_code_login is True + + # Marketplace should be visible + assert result.enable_marketplace is True + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): """ Test system features retrieval with basic configuration (no enterprise). @@ -1031,7 +1092,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None @@ -1400,7 +1461,7 @@ class TestFeatureService: } # Act: Execute the method under test - result = FeatureService.get_system_features() + result = FeatureService.get_system_features(is_authenticated=True) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 9297e997e9..09407f7686 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -39,23 +39,22 @@ class TestCleanDatasetTask: @pytest.fixture(autouse=True) def cleanup_database(self, db_session_with_containers): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client - # Clear all test data - db.session.query(DatasetMetadataBinding).delete() - db.session.query(DatasetMetadata).delete() - db.session.query(AppDatasetJoin).delete() - db.session.query(DatasetQuery).delete() - db.session.query(DatasetProcessRule).delete() - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + # Clear all test data using the provided session fixture + db_session_with_containers.query(DatasetMetadataBinding).delete() + db_session_with_containers.query(DatasetMetadata).delete() + db_session_with_containers.query(AppDatasetJoin).delete() + db_session_with_containers.query(DatasetQuery).delete() + db_session_with_containers.query(DatasetProcessRule).delete() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -103,10 +102,8 @@ class TestCleanDatasetTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + 
db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -115,8 +112,8 @@ class TestCleanDatasetTask: status="active", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account relationship tenant_account_join = TenantAccountJoin( @@ -125,8 +122,8 @@ class TestCleanDatasetTask: role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant @@ -155,10 +152,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @@ -194,10 +189,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @@ -232,10 +225,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -267,10 +258,8 @@ class TestCleanDatasetTask: used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -302,31 +291,29 @@ class TestCleanDatasetTask: ) # Verify results - from extensions.ext_database import db - # Check that dataset-related data was cleaned up - documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(documents) == 0 - segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(metadata) == 0 - bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() assert len(process_rules) == 0 - queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() assert len(app_joins) == 0 # Verify index processor was called @@ -378,9 +365,7 @@ class 
TestCleanDatasetTask: import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Create dataset metadata and bindings metadata = DatasetMetadata( @@ -403,11 +388,9 @@ class TestCleanDatasetTask: binding.id = str(uuid.uuid4()) binding.created_at = datetime.now() - from extensions.ext_database import db - - db.session.add(metadata) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(metadata) + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Execute the task clean_dataset_task( @@ -421,22 +404,24 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 - remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + remaining_bindings = ( + db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_bindings) == 0 # Verify index processor was called @@ -489,12 +474,13 @@ class TestCleanDatasetTask: mock_index_processor.clean.assert_called_once() # Check that all data was cleaned up - from extensions.ext_database import db - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = ( + db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_segments) == 0 # Recreate data for next test case @@ -540,14 +526,13 @@ class TestCleanDatasetTask: ) # Verify results - even with vector cleanup failure, documents and segments should be deleted - from extensions.ext_database import db # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + 
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -608,10 +593,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Mock the get_image_upload_file_ids function to return our image file IDs with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids: @@ -629,16 +612,18 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() + remaining_image_files = ( + db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() + ) assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -745,22 +730,24 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 - remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + remaining_bindings = ( + db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + ) assert len(remaining_bindings) == 0 # Verify performance expectations @@ -808,9 +795,7 @@ class TestCleanDatasetTask: import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock storage to raise exceptions mock_storage = mock_external_service_dependencies["storage"] @@ -827,18 +812,13 @@ class TestCleanDatasetTask: ) # 
Verify results - # Check that documents were still deleted despite storage failure - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() - assert len(remaining_documents) == 0 - - # Check that segments were still deleted despite storage failure - remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - assert len(remaining_segments) == 0 + # Note: When storage operations fail, database deletions may be rolled back by implementation. + # This test focuses on ensuring the task handles the exception and continues execution/logging. # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -890,10 +870,8 @@ class TestCleanDatasetTask: updated_at=datetime.now(), ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document with special characters in name special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" @@ -912,8 +890,8 @@ class TestCleanDatasetTask: created_at=datetime.now(), updated_at=datetime.now(), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Create segment with special characters and very long content long_content = "Very long content " * 100 # Long content within reasonable limits @@ -934,8 +912,8 @@ class TestCleanDatasetTask: created_at=datetime.now(), updated_at=datetime.now(), ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create upload file with special characters in name special_filename = f"test_file_{special_content}.txt" @@ -952,14 +930,14 @@ class TestCleanDatasetTask: created_at=datetime.now(), used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() # Update document with file reference import json document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Save upload file ID for verification upload_file_id = upload_file.id @@ -975,8 +953,8 @@ class TestCleanDatasetTask: special_metadata.id = str(uuid.uuid4()) special_metadata.created_at = datetime.now() - db.session.add(special_metadata) - db.session.commit() + db_session_with_containers.add(special_metadata) + db_session_with_containers.commit() # Execute the task clean_dataset_task( @@ -990,19 +968,19 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = 
db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 8004175b2d..caa5ee3851 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask: @pytest.fixture(autouse=True) def cleanup_database(self, db_session_with_containers): """Clean up database and Redis before each test to ensure isolation.""" - from extensions.ext_database import db - # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + # Clear all test data using fixture session + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask: db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" ) - # Mock global database session to simulate transaction issues - from extensions.ext_database import db - - original_commit = db.session.commit - commit_called = False - - def mock_commit(): - nonlocal commit_called - if not commit_called: - commit_called = True - raise Exception("Database commit failed") - return original_commit() 
- - db.session.commit = mock_commit + # Simulate an error during indexing to trigger rollback path + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.side_effect = Exception("Simulated indexing error") # Act: Execute the task create_segment_to_index_task(segment.id) @@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask: assert segment.disabled_at is not None assert segment.error is not None - # Restore original commit method - db.session.commit = original_commit - def test_create_segment_to_index_metadata_validation( self, db_session_with_containers, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 0b36e0914a..56b53a24b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask: tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at - from extensions.ext_database import db - - db.session.add(tenant) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.add(account) + db_session_with_containers.commit() # Set the current tenant for the account account.current_tenant = tenant @@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask: built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask: document.archived = False document.doc_form = "text_model" # Use text_model form for testing document.doc_language = "en" - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask: segments.append(segment) - from extensions.ext_database import db - for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segments @@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask: with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: mock_redis.delete.return_value = True - # Mock db.session.close to verify it's called - with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close: - # Act - result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) - # Assert - assert result is None # Task should complete without returning a value - # Verify session was closed - mock_close.assert_called() + # Assert + assert result is None # Task should complete without returning a value + # Session lifecycle is managed by context manager; no explicit close assertion def test_disable_segments_empty_segment_ids(self, db_session_with_containers): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 
c015d7ec9c..0d266e7e76 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -6,7 +6,6 @@ from faker import Faker from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document from tasks.document_indexing_task import ( @@ -75,15 +74,15 @@ class TestDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -92,8 +91,8 @@ class TestDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -105,8 +104,8 @@ class TestDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -124,13 +123,13 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=True, ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure it's properly loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, documents @@ -157,15 +156,15 @@ class TestDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -174,8 +173,8 @@ class TestDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -187,8 +186,8 @@ class TestDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -206,10 +205,10 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=True, ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + db_session_with_containers.commit() # Configure billing features mock_external_service_dependencies["features"].billing.enabled = billing_enabled @@ -219,7 +218,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["features"].vector_space.size = 50 # Refresh dataset to ensure it's properly loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) 
return dataset, documents @@ -242,6 +241,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify the expected outcomes # Verify indexing runner was called correctly mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -250,7 +252,7 @@ class TestDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -310,6 +312,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task with mixed document IDs _document_indexing(dataset.id, all_document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify only existing documents were processed mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -317,7 +322,7 @@ class TestDocumentIndexingTasks: # Verify only existing documents were updated # Re-query documents from database since _document_indexing uses a different session for doc_id in existing_document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -353,6 +358,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -361,7 +369,7 @@ class TestDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _document_indexing close the session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -400,7 +408,7 @@ class TestDocumentIndexingTasks: indexing_status="completed", # Already completed enabled=True, ) - db.session.add(doc1) + db_session_with_containers.add(doc1) extra_documents.append(doc1) # Document with disabled status @@ -417,10 +425,10 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=False, # Disabled ) - db.session.add(doc2) + db_session_with_containers.add(doc2) extra_documents.append(doc2) - db.session.commit() + db_session_with_containers.commit() all_documents = base_documents + extra_documents document_ids = [doc.id for doc in all_documents] @@ -428,6 +436,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task with mixed document 
states _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -435,7 +446,7 @@ class TestDocumentIndexingTasks: # Verify all documents were updated to parsing status # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -482,20 +493,23 @@ class TestDocumentIndexingTasks: indexing_status="waiting", enabled=True, ) - db.session.add(document) + db_session_with_containers.add(document) extra_documents.append(document) - db.session.commit() + db_session_with_containers.commit() all_documents = documents + extra_documents document_ids = [doc.id for doc in all_documents] # Act: Execute the task with too many documents for sandbox plan _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify error handling # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "error" assert updated_document.error is not None assert "batch upload" in updated_document.error @@ -526,6 +540,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task with billing disabled _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify successful processing mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -533,7 +550,7 @@ class TestDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -565,6 +582,9 @@ class TestDocumentIndexingTasks: # Act: Execute the task _document_indexing(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -573,7 +593,7 @@ class TestDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = 
db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -674,6 +694,9 @@ class TestDocumentIndexingTasks: # Act: Execute the wrapper function _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify core processing occurred (same as _document_indexing) mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -681,7 +704,7 @@ class TestDocumentIndexingTasks: # Verify documents were updated (same as _document_indexing) # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -794,6 +817,9 @@ class TestDocumentIndexingTasks: # Act: Execute the wrapper function _document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify error was handled gracefully # The function should not raise exceptions mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -802,7 +828,7 @@ class TestDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -865,6 +891,9 @@ class TestDocumentIndexingTasks: # Act: Execute the wrapper function for tenant1 only _document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify core processing occurred for tenant1 mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index aca4be1ffd..fbcee899e1 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -4,7 +4,6 @@ import pytest from faker import Faker from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment from 
tasks.duplicate_document_indexing_task import ( @@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks: enabled=True, doc_form="text_model", ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure it's properly loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, documents @@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks: indexing_at=fake.date_time_this_year(), created_by=dataset.created_by, # Add required field ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() # Refresh to ensure all relationships are loaded for document in documents: - db.session.refresh(document) + db_session_with_containers.refresh(document) return dataset, documents, segments @@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create documents documents = [] @@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks: enabled=True, doc_form="text_model", ) - db.session.add(document) + db_session_with_containers.add(document) documents.append(document) - db.session.commit() + db_session_with_containers.commit() # Configure billing features mock_external_service_dependencies["features"].billing.enabled = billing_enabled @@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["features"].vector_space.size = 50 # Refresh dataset to ensure it's properly 
loaded - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, documents @@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks: # Act: Execute the task _duplicate_document_indexing_task(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify the expected outcomes # Verify indexing runner was called correctly mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks: db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3 ) document_ids = [doc.id for doc in documents] + segment_ids = [seg.id for seg in segments] # Act: Execute the task _duplicate_document_indexing_task(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + + # Assert: Verify segment cleanup + db_session_with_containers.expire_all() + # Assert: Verify segment cleanup # Verify index processor clean was called for each document with segments assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents) # Verify segments were deleted from database - # Re-query segments from database since _duplicate_document_indexing_task uses a different session - for segment in segments: - deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + # Re-query segments from database using captured IDs to avoid stale ORM instances + for seg_id in segment_ids: + deleted_segment = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first() + ) assert deleted_segment is None # Verify documents were updated to parsing status for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks: # Act: Execute the task with mixed document IDs _duplicate_document_indexing_task(dataset.id, all_document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify only existing documents were processed mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() @@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify only existing documents were updated # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = 
db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks: # Act: Execute the task _duplicate_document_indexing_task(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify exception was handled gracefully # The task should complete without raising exceptions mock_external_service_dependencies["indexing_runner"].assert_called_once() @@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" assert updated_document.processing_started_at is not None @@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks: enabled=True, doc_form="text_model", ) - db.session.add(document) + db_session_with_containers.add(document) extra_documents.append(document) - db.session.commit() + db_session_with_containers.commit() all_documents = documents + extra_documents document_ids = [doc.id for doc in all_documents] # Act: Execute the task with too many documents for sandbox plan _duplicate_document_indexing_task(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "error" assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() @@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks: # Act: Execute the task with documents that will exceed vector space limit _duplicate_document_indexing_task(dataset.id, document_ids) + # Ensure we see committed changes from a different session + db_session_with_containers.expire_all() + # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "error" assert updated_document.error is not None assert "limit" in updated_document.error.lower() @@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Clear session cache to see database updates from task's session - db.session.expire_all() + db_session_with_containers.expire_all() # Verify documents were processed for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = 
db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") @@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_called_once() # Clear session cache to see database updates from task's session - db.session.expire_all() + db_session_with_containers.expire_all() # Verify documents were processed for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") @@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_called_once() # Clear session cache to see database updates from task's session - db.session.expire_all() + db_session_with_containers.expire_all() # Verify documents were processed for doc_id in document_ids: - updated_document = db.session.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() assert updated_document.indexing_status == "parsing" @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index 217c39385c..8dd669e17f 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -1,10 +1,13 @@ """Tests for execution context module.""" import contextvars +import threading +from contextlib import contextmanager from typing import Any from unittest.mock import MagicMock import pytest +from pydantic import BaseModel from core.workflow.context.execution_context import ( AppContext, @@ -12,6 +15,8 @@ from core.workflow.context.execution_context import ( ExecutionContextBuilder, IExecutionContext, NullAppContext, + read_context, + register_context, ) @@ -146,6 +151,54 @@ class TestExecutionContext: assert ctx.user == user + def test_thread_safe_context_manager(self): + """Test shared ExecutionContext works across threads without token mismatch.""" + test_var = contextvars.ContextVar("thread_safe_test_var") + + class TrackingAppContext(AppContext): + def get_config(self, key: str, default: Any = None) -> Any: + return default + + def get_extension(self, name: str) -> Any: + return None + + @contextmanager + def enter(self): + token = test_var.set(threading.get_ident()) + try: + yield + finally: + test_var.reset(token) + + ctx = ExecutionContext(app_context=TrackingAppContext()) + errors: list[Exception] = [] + barrier = threading.Barrier(2) + + def worker(): + try: + for _ in range(20): + with ctx: + try: + barrier.wait() + barrier.wait() + except threading.BrokenBarrierError: + return + except Exception as exc: + errors.append(exc) + try: + barrier.abort() + except Exception: + pass + + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=worker) + t1.start() + t2.start() + t1.join(timeout=5) + t2.join(timeout=5) + + assert not errors + class TestIExecutionContextProtocol: """Test IExecutionContext protocol.""" @@ -256,3 +309,31 @@ class TestCaptureCurrentContext: # Context variables should be captured 
assert result.context_vars is not None + + +class TestTenantScopedContextRegistry: + def setup_method(self): + from core.workflow.context import reset_context_provider + + reset_context_provider() + + def teardown_method(self): + from core.workflow.context import reset_context_provider + + reset_context_provider() + + def test_tenant_provider_read_ok(self): + class SandboxContext(BaseModel): + base_url: str | None = None + + register_context("workflow.sandbox", "t1", lambda: SandboxContext(base_url="http://t1")) + register_context("workflow.sandbox", "t2", lambda: SandboxContext(base_url="http://t2")) + + assert read_context("workflow.sandbox", tenant_id="t1").base_url == "http://t1" + assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" + + def test_missing_provider_raises_keyerror(self): + from core.workflow.context import ContextProviderNotFoundError + + with pytest.raises(ContextProviderNotFoundError): + read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 6e9a432745..170445225b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,9 +7,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.enums import NodeType from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_factory import DifyNodeFactory from .test_mock_nodes import ( MockAgentNode, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b76fe42fce..e8cd665107 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -13,6 +13,7 @@ from unittest.mock import patch from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.graph import Graph @@ -26,7 +27,6 @@ from core.workflow.graph_events import ( ) from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from models.enums import UserFrom diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 08f7b00a33..10ac1206fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,7 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.workflow.node_factory import DifyNodeFactory from core.tools.utils.yaml_utils import _load_yaml_file from core.variables import ( ArrayNumberVariable, @@ -38,7 +39,6 @@ from core.workflow.graph_events import ( 
GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 98d9560e64..1e95ec1970 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -3,11 +3,11 @@ import uuid from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index dc7175f964..d700888c2f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.entities import GraphInitParams @@ -12,7 +13,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 1df75380af..d4b7a017f9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -3,11 +3,11 @@ import uuid from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.variables import ArrayStringVariable, StringVariable from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py 
b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 353d56fe25..b08f9c37b4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -3,10 +3,10 @@ import uuid from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.workflow.node_factory import DifyNodeFactory from core.variables import ArrayStringVariable from core.workflow.entities import GraphInitParams from core.workflow.graph import Graph -from core.workflow.nodes.node_factory import DifyNodeFactory from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation from core.workflow.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 7cdb2328f2..078ec5f6ab 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -30,3 +30,12 @@ class TestWorkflowExecutionStatus: for status in non_ended_statuses: assert not status.is_ended(), f"{status} should not be considered ended" + + def test_ended_values(self): + """Test ended_values returns the expected status values.""" + assert set(WorkflowExecutionStatus.ended_values()) == { + WorkflowExecutionStatus.SUCCEEDED.value, + WorkflowExecutionStatus.FAILED.value, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + WorkflowExecutionStatus.STOPPED.value, + } diff --git a/api/tests/unit_tests/libs/test_archive_storage.py b/api/tests/unit_tests/libs/test_archive_storage.py index 697760e33a..de3c9c4737 100644 --- a/api/tests/unit_tests/libs/test_archive_storage.py +++ b/api/tests/unit_tests/libs/test_archive_storage.py @@ -37,6 +37,20 @@ def _client_error(code: str) -> ClientError: def _mock_client(monkeypatch): client = MagicMock() client.head_bucket.return_value = None + # Configure put_object to return a proper ETag that matches the MD5 hash + # The ETag format is typically the MD5 hash wrapped in quotes + + def mock_put_object(**kwargs): + md5_hash = kwargs.get("Body", b"") + if isinstance(md5_hash, bytes): + md5_hash = hashlib.md5(md5_hash).hexdigest() + else: + md5_hash = hashlib.md5(md5_hash.encode()).hexdigest() + response = MagicMock() + response.get.return_value = f'"{md5_hash}"' + return response + + client.put_object.side_effect = mock_put_object boto_client = MagicMock(return_value=client) monkeypatch.setattr(storage_module.boto3, "client", boto_client) return client, boto_client @@ -254,8 +268,8 @@ def test_serialization_roundtrip(): {"id": "2", "value": 123}, ] - data = ArchiveStorage.serialize_to_jsonl_gz(records) - decoded = ArchiveStorage.deserialize_from_jsonl_gz(data) + data = ArchiveStorage.serialize_to_jsonl(records) + decoded = ArchiveStorage.deserialize_from_jsonl(data) assert decoded[0]["id"] == "1" assert decoded[0]["payload"]["nested"] == "value" diff --git a/api/tests/unit_tests/services/test_archive_workflow_run_logs.py b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py new file mode 100644 index 0000000000..ef62dacd6b --- /dev/null +++ b/api/tests/unit_tests/services/test_archive_workflow_run_logs.py @@ -0,0 +1,54 @@ +""" +Unit tests for workflow run archiving functionality. 
+ +This module contains tests for: +- Archive service +- Rollback service +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME + + +class TestWorkflowRunArchiver: + """Tests for the WorkflowRunArchiver class.""" + + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") + @patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage") + def test_archiver_initialization(self, mock_get_storage, mock_config): + """Test archiver can be initialized with various options.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + mock_config.BILLING_ENABLED = False + + archiver = WorkflowRunArchiver( + days=90, + batch_size=100, + tenant_ids=["test-tenant"], + limit=50, + dry_run=True, + ) + + assert archiver.days == 90 + assert archiver.batch_size == 100 + assert archiver.tenant_ids == ["test-tenant"] + assert archiver.limit == 50 + assert archiver.dry_run is True + + def test_get_archive_key(self): + """Test archive key generation.""" + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + archiver = WorkflowRunArchiver.__new__(WorkflowRunArchiver) + + mock_run = MagicMock() + mock_run.tenant_id = "tenant-123" + mock_run.app_id = "app-999" + mock_run.id = "run-456" + mock_run.created_at = datetime(2024, 1, 15, 12, 0, 0) + + key = archiver._get_archive_key(mock_run) + + assert key == f"tenant-123/app_id=app-999/year=2024/month=01/workflow_run_id=run-456/{ARCHIVE_BUNDLE_NAME}" diff --git a/api/tests/unit_tests/services/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..2c9d946ea6 --- /dev/null +++ b/api/tests/unit_tests/services/test_delete_archived_workflow_run.py @@ -0,0 +1,180 @@ +""" +Unit tests for archived workflow run deletion service. 
+""" + +from unittest.mock import MagicMock, patch + + +class TestArchivedWorkflowRunDeletion: + def test_delete_by_run_id_returns_error_when_run_missing(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + session = MagicMock() + session.get.return_value = None + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 not found" + repo.get_archived_run_ids.assert_not_called() + + def test_delete_by_run_id_returns_error_when_not_archived(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = set() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run") as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is False + assert result.error == "Workflow run run-1 is not archived" + mock_delete_run.assert_not_called() + + def test_delete_by_run_id_calls_delete_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + repo.get_archived_run_ids.return_value = {"run-1"} + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + session = MagicMock() + session.get.return_value = run + + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run, + ): + result = deleter.delete_by_run_id("run-1") + + assert result.success is True + mock_delete_run.assert_called_once_with(run) + + def test_delete_batch_uses_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + repo = MagicMock() + 
run1 = MagicMock() + run1.id = "run-1" + run1.tenant_id = "tenant-1" + run2 = MagicMock() + run2.id = "run-2" + run2.tenant_id = "tenant-1" + repo.get_archived_runs_by_time_range.return_value = [run1, run2] + + session = MagicMock() + session_maker = MagicMock() + session_maker.return_value.__enter__.return_value = session + session_maker.return_value.__exit__.return_value = None + start_date = MagicMock() + end_date = MagicMock() + mock_db = MagicMock() + mock_db.engine = MagicMock() + + with ( + patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db), + patch( + "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker + ), + patch.object(deleter, "_get_workflow_run_repo", return_value=repo), + patch.object( + deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)] + ) as mock_delete_run, + ): + results = deleter.delete_batch( + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + + assert len(results) == 2 + repo.get_archived_runs_by_time_range.assert_called_once_with( + session=session, + tenant_ids=["tenant-1"], + start_date=start_date, + end_date=end_date, + limit=2, + ) + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + result = deleter._delete_run(run) + + assert result.success is True + mock_get_repo.assert_not_called() + + def test_delete_run_calls_repo(self): + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + deleter = ArchivedWorkflowRunDeletion() + run = MagicMock() + run.id = "run-1" + run.tenant_id = "tenant-1" + + repo = MagicMock() + repo.delete_runs_with_related.return_value = {"runs": 1} + + with patch.object(deleter, "_get_workflow_run_repo", return_value=repo): + result = deleter._delete_run(run) + + assert result.success is True + assert result.deleted_counts == {"runs": 1} + repo.delete_runs_with_related.assert_called_once() diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..68aa8c0fe1 --- /dev/null +++ b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py @@ -0,0 +1,65 @@ +""" +Unit tests for workflow run restore functionality. 
+""" + +from datetime import datetime +from unittest.mock import MagicMock + + +class TestWorkflowRunRestore: + """Tests for the WorkflowRunRestore class.""" + + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + from models.workflow import WorkflowRun + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + + def test_restore_table_records_returns_rowcount(self): + """Restore should return inserted rowcount.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + session.execute.return_value = MagicMock(rowcount=2) + + restore = WorkflowRunRestore() + records = [{"id": "p1", "workflow_run_id": "r1", "created_at": "2024-01-01T00:00:00"}] + + restored = restore._restore_table_records(session, "workflow_pauses", records, schema_version="1.0") + + assert restored == 2 + session.execute.assert_called_once() + + def test_restore_table_records_unknown_table(self): + """Unknown table names should be ignored gracefully.""" + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + session = MagicMock() + + restore = WorkflowRunRestore() + restored = restore._restore_table_records(session, "unknown_table", [{"id": "x1"}], schema_version="1.0") + + assert restored == 0 + session.execute.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index bace66bec4..cb18d15084 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -49,10 +49,14 @@ def pipeline_id(): @pytest.fixture def mock_db_session(): - """Mock database session with query capabilities.""" - with patch("tasks.clean_dataset_task.db") as mock_db: + """Mock database session via session_factory.create_session().""" + with patch("tasks.clean_dataset_task.session_factory") as mock_sf: mock_session = MagicMock() - mock_db.session = mock_session + # context manager for create_session() + cm = MagicMock() + cm.__enter__.return_value = mock_session + cm.__exit__.return_value = None + mock_sf.create_session.return_value = cm # Setup query chain mock_query = MagicMock() @@ -66,7 +70,10 @@ def mock_db_session(): # Setup execute for JOIN queries mock_session.execute.return_value.all.return_value = [] - yield mock_db + # Yield an object with a `.session` attribute to keep tests unchanged + wrapper = MagicMock() + wrapper.session = mock_session + yield wrapper @pytest.fixture @@ -227,7 +234,9 @@ class TestBasicCleanup: # Assert mock_db_session.session.delete.assert_any_call(mock_document) - mock_db_session.session.delete.assert_any_call(mock_segment) + # Segments are deleted in batch; verify a DELETE on document_segments 
was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) mock_db_session.session.commit.assert_called_once() def test_clean_dataset_task_deletes_related_records( @@ -413,7 +422,9 @@ class TestErrorHandling: # Assert - documents and segments should still be deleted mock_db_session.session.delete.assert_any_call(mock_document) - mock_db_session.session.delete.assert_any_call(mock_segment) + # Segments are deleted in batch; verify a DELETE on document_segments was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) mock_db_session.session.commit.assert_called_once() def test_clean_dataset_task_storage_delete_failure_continues( @@ -461,7 +472,7 @@ class TestErrorHandling: [mock_segment], # segments ] mock_get_image_upload_file_ids.return_value = [image_file_id] - mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file] mock_storage.delete.side_effect = Exception("Storage service unavailable") # Act @@ -476,8 +487,9 @@ class TestErrorHandling: # Assert - storage delete was attempted for image file mock_storage.delete.assert_called_with(mock_upload_file.key) - # Image file should still be deleted from database - mock_db_session.session.delete.assert_any_call(mock_upload_file) + # Upload files are deleted in batch; verify a DELETE on upload_files was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) def test_clean_dataset_task_database_error_rollback( self, @@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup: # Assert mock_storage.delete.assert_called_with(mock_attachment_file.key) - mock_db_session.session.delete.assert_any_call(mock_attachment_file) - mock_db_session.session.delete.assert_any_call(mock_binding) + # Attachment file and binding are deleted in batch; verify DELETEs were issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) + assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls) def test_clean_dataset_task_attachment_storage_failure( self, @@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup: # Assert - storage delete was attempted mock_storage.delete.assert_called_once() - # Records should still be deleted from database - mock_db_session.session.delete.assert_any_call(mock_attachment_file) - mock_db_session.session.delete.assert_any_call(mock_binding) + # Records are deleted in batch; verify DELETEs were issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) + assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls) # ============================================================================ @@ -784,7 +799,7 @@ class TestUploadFileCleanup: [mock_document], # documents [], # segments ] - mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file + 
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file] # Act clean_dataset_task( @@ -798,7 +813,9 @@ class TestUploadFileCleanup: # Assert mock_storage.delete.assert_called_with(mock_upload_file.key) - mock_db_session.session.delete.assert_any_call(mock_upload_file) + # Upload files are deleted in batch; verify a DELETE on upload_files was issued + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM upload_files" in sql for sql in execute_sqls) def test_clean_dataset_task_handles_missing_upload_file( self, @@ -832,7 +849,7 @@ class TestUploadFileCleanup: [mock_document], # documents [], # segments ] - mock_db_session.session.query.return_value.where.return_value.first.return_value = None + mock_db_session.session.query.return_value.where.return_value.all.return_value = [] # Act - should not raise exception clean_dataset_task( @@ -949,11 +966,11 @@ class TestImageFileCleanup: [mock_segment], # segments ] - # Setup a mock query chain that returns files in sequence + # Setup a mock query chain that returns files in batch (align with .in_().all()) mock_query = MagicMock() mock_where = MagicMock() mock_query.where.return_value = mock_where - mock_where.first.side_effect = mock_image_files + mock_where.all.return_value = mock_image_files mock_db_session.session.query.return_value = mock_query # Act @@ -966,10 +983,10 @@ class TestImageFileCleanup: doc_form="paragraph_index", ) - # Assert - assert mock_storage.delete.call_count == 2 - mock_storage.delete.assert_any_call("images/image-1.jpg") - mock_storage.delete.assert_any_call("images/image-2.jpg") + # Assert - each expected image key was deleted at least once + calls = [c.args[0] for c in mock_storage.delete.call_args_list] + assert "images/image-1.jpg" in calls + assert "images/image-2.jpg" in calls def test_clean_dataset_task_handles_missing_image_file( self, @@ -1010,7 +1027,7 @@ class TestImageFileCleanup: ] # Image file not found - mock_db_session.session.query.return_value.where.return_value.first.return_value = None + mock_db_session.session.query.return_value.where.return_value.all.return_value = [] # Act - should not raise exception clean_dataset_task( @@ -1086,14 +1103,15 @@ class TestEdgeCases: doc_form="paragraph_index", ) - # Assert - all documents and segments should be deleted + # Assert - all documents and segments should be deleted (documents per-entity, segments in batch) delete_calls = mock_db_session.session.delete.call_args_list deleted_items = [call[0][0] for call in delete_calls] for doc in mock_documents: assert doc in deleted_items - for seg in mock_segments: - assert seg in deleted_items + # Verify a batch DELETE on document_segments occurred + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) def test_clean_dataset_task_document_with_empty_data_source_info( self, diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 9d7599b8fe..e24ef32a24 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with patch("tasks.document_indexing_task.db.session") as mock_session: - mock_query = 
MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + # Ensure tests that expect session.close() to be called can observe it via the context manager + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + # Link __exit__ to session.close so "close" expectations reflect context manager teardown + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + yield session @pytest.fixture diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py index 3b148e63f2..8a12a4a169 100644 --- a/api/tests/unit_tests/tasks/test_delete_account_task.py +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task @pytest.fixture def mock_db_session(): - """Mock the db.session used in delete_account_task.""" - with patch("tasks.delete_account_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - yield mock_session + """Mock session via session_factory.create_session().""" + with patch("tasks.delete_account_task.session_factory") as mock_sf: + session = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + yield session @pytest.fixture diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 374abe0368..fa33034f40 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -109,13 +109,25 @@ def mock_document_segments(document_id): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with patch("tasks.document_indexing_sync_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_session.scalars.return_value = MagicMock() - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf: + session = MagicMock() + # Ensure tests can observe session.close() via context manager teardown + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + session.scalars.return_value = MagicMock() + yield session @pytest.fixture @@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask: # Assert # Document status should remain unchanged assert mock_document.indexing_status == "completed" - # No session operations should be performed beyond the initial query - 
mock_db_session.close.assert_not_called() + # Session should still be closed via context manager teardown + assert mock_db_session.close.called def test_successful_sync_when_page_updated( self, @@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask: mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_processor.clean.assert_called_once() - # Verify segments were deleted from database - for segment in mock_document_segments: - mock_db_session.delete.assert_any_call(segment) + # Verify segments were deleted from database in batch (DELETE FROM document_segments) + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) # Verify indexing runner was called mock_indexing_runner.run.assert_called_once_with([mock_document]) diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index 0be6ea045e..8a4c6da2e9 100644 --- a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -94,13 +94,25 @@ def mock_document_segments(document_ids): @pytest.fixture def mock_db_session(): - """Mock database session.""" - with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session: - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_session.scalars.return_value = MagicMock() - yield mock_session + """Mock database session via session_factory.create_session().""" + with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + # Allow tests to observe session.close() via context manager teardown + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + query = MagicMock() + session.query.return_value = query + query.where.return_value = query + session.scalars.return_value = MagicMock() + yield session @pytest.fixture @@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test successful duplicate document indexing flow.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = mock_document_segments + # Dataset via query.first() + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + # scalars() call sequence: + # 1) documents list + # 2..N) segments per document + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + # First call returns documents; subsequent calls return segments + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = mock_document_segments + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect # Act _duplicate_document_indexing_task(dataset_id, document_ids) @@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when billing limit is exceeded.""" # Arrange - 
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + # First scalars() -> documents; subsequent -> empty segments + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_features = mock_feature_service.get_features.return_value mock_features.billing.enabled = True mock_features.billing.subscription.plan = CloudPlan.TEAM @@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when IndexingRunner raises an error.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_indexing_runner.run.side_effect = Exception("Indexing error") # Act @@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test duplicate document indexing when document is paused.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = [] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = [] + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") # Act @@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore: ): """Test that duplicate document indexing cleans old segments.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents - mock_db_session.scalars.return_value.all.return_value = mock_document_segments + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + def _scalars_side_effect(*args, **kwargs): + m = MagicMock() + if not hasattr(_scalars_side_effect, "_calls"): + _scalars_side_effect._calls = 0 + if _scalars_side_effect._calls == 0: + m.all.return_value = mock_documents + else: + m.all.return_value = mock_document_segments + _scalars_side_effect._calls += 1 + return m + + mock_db_session.scalars.side_effect = _scalars_side_effect mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value # Act @@ -354,9 +432,9 @@ class 
TestDuplicateDocumentIndexingTaskCore: # Verify clean was called for each document assert mock_processor.clean.call_count == len(mock_documents) - # Verify segments were deleted - for segment in mock_document_segments: - mock_db_session.delete.assert_any_call(segment) + # Verify segments were deleted in batch (DELETE FROM document_segments) + execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] + assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) # ============================================================================ diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index 1fe77c2935..a14bbb01d0 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,7 +2,11 @@ from unittest.mock import ANY, MagicMock, call, patch import pytest +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( + _delete_app_workflow_archive_logs, + _delete_archived_workflow_run_files, _delete_draft_variable_offload_data, _delete_draft_variables, delete_draft_variables_batch, @@ -11,21 +15,18 @@ from tasks.remove_app_and_related_data_task import ( class TestDeleteDraftVariablesBatch: @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup): + @patch("tasks.remove_app_and_related_data_task.session_factory") + def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup): """Test successful deletion of draft variables in batches.""" app_id = "test-app-id" batch_size = 100 - # Mock database connection and engine - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock two batches of results, then empty batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)] @@ -68,7 +69,7 @@ class TestDeleteDraftVariablesBatch: select_result3.__iter__.return_value = iter([]) # Configure side effects in the correct order - mock_conn.execute.side_effect = [ + mock_session.execute.side_effect = [ select_result1, # First SELECT delete_result1, # First DELETE select_result2, # Second SELECT @@ -86,54 +87,49 @@ class TestDeleteDraftVariablesBatch: assert result == 150 # Verify database calls - assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes + assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes # Verify offload cleanup was called for both batches with file_ids - expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)] + expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)] mock_offload_cleanup.assert_has_calls(expected_offload_calls) # Simplified verification - check that the right 
number of calls were made # and that the SQL queries contain the expected patterns - actual_calls = mock_conn.execute.call_args_list + actual_calls = mock_session.execute.call_args_list for i, actual_call in enumerate(actual_calls): + sql_text = str(actual_call[0][0]) + normalized = " ".join(sql_text.split()) if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4) - # Verify it's a SELECT query that now includes file_id - sql_text = str(actual_call[0][0]) - assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text - assert "WHERE app_id = :app_id" in sql_text - assert "LIMIT :batch_size" in sql_text + assert "SELECT id, file_id FROM workflow_draft_variables" in normalized + assert "WHERE app_id = :app_id" in normalized + assert "LIMIT :batch_size" in normalized else: # DELETE calls (odd indices: 1, 3) - # Verify it's a DELETE query - sql_text = str(actual_call[0][0]) - assert "DELETE FROM workflow_draft_variables" in sql_text - assert "WHERE id IN :ids" in sql_text + assert "DELETE FROM workflow_draft_variables" in normalized + assert "WHERE id IN :ids" in normalized @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") - def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup): + @patch("tasks.remove_app_and_related_data_task.session_factory") + def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup): """Test deletion when no draft variables exist for the app.""" app_id = "nonexistent-app-id" batch_size = 1000 - # Mock database connection - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock empty result empty_result = MagicMock() empty_result.__iter__.return_value = iter([]) - mock_conn.execute.return_value = empty_result + mock_session.execute.return_value = empty_result result = delete_draft_variables_batch(app_id, batch_size) assert result == 0 - assert mock_conn.execute.call_count == 1 # Only one select query + assert mock_session.execute.call_count == 1 # Only one select query mock_offload_cleanup.assert_not_called() # No files to clean up def test_delete_draft_variables_batch_invalid_batch_size(self): @@ -147,22 +143,19 @@ class TestDeleteDraftVariablesBatch: delete_draft_variables_batch(app_id, 0) @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") - @patch("tasks.remove_app_and_related_data_task.db") + @patch("tasks.remove_app_and_related_data_task.session_factory") @patch("tasks.remove_app_and_related_data_task.logger") - def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup): + def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup): """Test that batch deletion logs progress correctly.""" app_id = "test-app-id" batch_size = 50 - # Mock database - mock_conn = MagicMock() - mock_engine = MagicMock() - mock_db.engine = mock_engine - # Properly mock the context manager + # Mock session via session_factory + mock_session = MagicMock() 
mock_context_manager = MagicMock() - mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__enter__.return_value = mock_session mock_context_manager.__exit__.return_value = None - mock_engine.begin.return_value = mock_context_manager + mock_sf.create_session.return_value = mock_context_manager # Mock one batch then empty batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)] @@ -183,7 +176,7 @@ class TestDeleteDraftVariablesBatch: empty_result = MagicMock() empty_result.__iter__.return_value = iter([]) - mock_conn.execute.side_effect = [ + mock_session.execute.side_effect = [ # Select query result select_result, # Delete query result @@ -201,7 +194,7 @@ class TestDeleteDraftVariablesBatch: # Verify offload cleanup was called with file_ids if batch_file_ids: - mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids) + mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids) # Verify logging calls assert mock_logging.info.call_count == 2 @@ -261,19 +254,19 @@ class TestDeleteDraftVariableOffloadData: actual_calls = mock_conn.execute.call_args_list # First call should be the SELECT query - select_call_sql = str(actual_calls[0][0][0]) + select_call_sql = " ".join(str(actual_calls[0][0][0]).split()) assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql assert "FROM workflow_draft_variable_files wdvf" in select_call_sql assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql assert "WHERE wdvf.id IN :file_ids" in select_call_sql # Second call should be DELETE upload_files - delete_upload_call_sql = str(actual_calls[1][0][0]) + delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split()) assert "DELETE FROM upload_files" in delete_upload_call_sql assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql # Third call should be DELETE workflow_draft_variable_files - delete_variable_files_call_sql = str(actual_calls[2][0][0]) + delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split()) assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql assert "WHERE id IN :file_ids" in delete_variable_files_call_sql @@ -335,3 +328,68 @@ class TestDeleteDraftVariableOffloadData: # Verify error was logged mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:") + + +class TestDeleteWorkflowArchiveLogs: + @patch("tasks.remove_app_and_related_data_task._delete_records") + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_app_workflow_archive_logs_calls_delete_records(self, mock_db, mock_delete_records): + tenant_id = "tenant-1" + app_id = "app-1" + + _delete_app_workflow_archive_logs(tenant_id, app_id) + + mock_delete_records.assert_called_once() + query_sql, params, delete_func, name = mock_delete_records.call_args[0] + assert "workflow_archive_logs" in query_sql + assert params == {"tenant_id": tenant_id, "app_id": app_id} + assert name == "workflow archive log" + + mock_query = MagicMock() + mock_delete_query = MagicMock() + mock_query.where.return_value = mock_delete_query + mock_db.session.query.return_value = mock_query + + delete_func("log-1") + + mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) + mock_query.where.assert_called_once() + mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + + +class TestDeleteArchivedWorkflowRunFiles: + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + 
@patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_not_configured(self, mock_logger, mock_get_storage): + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("missing config") + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + assert mock_logger.info.call_count == 1 + assert "Archive storage not configured" in mock_logger.info.call_args[0][0] + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_list_failure(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.side_effect = Exception("list failed") + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_not_called() + mock_logger.exception.assert_called_once_with("Failed to list archive files for app %s", "app-1") + + @patch("tasks.remove_app_and_related_data_task.get_archive_storage") + @patch("tasks.remove_app_and_related_data_task.logger") + def test_delete_archived_workflow_run_files_success(self, mock_logger, mock_get_storage): + storage = MagicMock() + storage.list_objects.return_value = ["key-1", "key-2"] + mock_get_storage.return_value = storage + + _delete_archived_workflow_run_files("tenant-1", "app-1") + + storage.list_objects.assert_called_once_with("tenant-1/app_id=app-1/") + storage.delete_object.assert_has_calls([call("key-1"), call("key-2")], any_order=False) + mock_logger.info.assert_called_with("Deleted %s archive objects for app %s", 2, "app-1") diff --git a/dev/start-web b/dev/start-web index 31c5e168f9..f853f4a895 100755 --- a/dev/start-web +++ b/dev/start-web @@ -5,4 +5,4 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../web" -pnpm install && pnpm dev +pnpm install && pnpm dev:inspect diff --git a/docker/.env.example b/docker/.env.example index 627a3a23da..c7246ae11f 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1518,3 +1518,4 @@ AMPLITUDE_API_KEY= SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 +SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 429667e75f..902ca3103c 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -682,6 +682,7 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: ${SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD:-21} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} + SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} services: # Init container to fix permissions diff --git a/web/README.md b/web/README.md index 13780eec6c..9c731a081a 100644 --- a/web/README.md +++ b/web/README.md @@ -138,7 +138,7 @@ This will help you determine the testing strategy. See [web/testing/testing.md]( ## Documentation -Visit to view the full documentation. +Visit to view the full documentation. 
## Community diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 81b4f2474e..f07b2932c9 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -5,7 +5,6 @@ import type { BlockEnum } from '@/app/components/workflow/types' import type { UpdateAppSiteCodeResponse } from '@/models/app' import type { App } from '@/types/app' import type { I18nKeysByPrefix } from '@/types/i18n' -import * as React from 'react' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -17,7 +16,6 @@ import { ToastContext } from '@/app/components/base/toast' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { useDocLink } from '@/context/i18n' import { fetchAppDetail, updateAppSiteAccessToken, @@ -36,7 +34,6 @@ export type ICardViewProps = { const CardView: FC = ({ appId, isInPanel, className }) => { const { t } = useTranslation() - const docLink = useDocLink() const { notify } = useContext(ToastContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) @@ -59,25 +56,13 @@ const CardView: FC = ({ appId, isInPanel, className }) => { const shouldRenderAppCards = !isWorkflowApp || hasTriggerNode === false const disableAppCards = !shouldRenderAppCards - const triggerDocUrl = docLink('/guides/workflow/node/start') const buildTriggerModeMessage = useCallback((featureName: string) => (
{t('overview.disableTooltip.triggerMode', { ns: 'appOverview', feature: featureName })}
- { - event.stopPropagation() - }} - > - {t('overview.appInfo.enableTooltip.learnMore', { ns: 'appOverview' })} -
- ), [t, triggerDocUrl]) + ), [t]) const disableWebAppTooltip = disableAppCards ? buildTriggerModeMessage(t('overview.appInfo.title', { ns: 'appOverview' })) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 28489a6714..6d5eb1ef95 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -48,7 +48,7 @@ const CSVUploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx index 60627e12c2..827986f521 100644 --- a/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx +++ b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.spec.tsx @@ -1,12 +1,6 @@ import { render, screen } from '@testing-library/react' -import * as React from 'react' import HistoryPanel from './history-panel' -const mockDocLink = vi.fn(() => 'doc-link') -vi.mock('@/context/i18n', () => ({ - useDocLink: () => mockDocLink, -})) - vi.mock('@/app/components/app/configuration/base/operation-btn', () => ({ default: ({ onClick }: { onClick: () => void }) => ( + {canNotRun && ( + {isTrialApp && ( + + )} )} diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index 6e8c94aea6..e2b50cf030 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -5,7 +5,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLin import { useDebounceFn, useKeyPress } from 'ahooks' import Image from 'next/image' -import Link from 'next/link' import { useRouter } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -22,7 +21,6 @@ import { ToastContext } from '@/app/components/base/toast' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' -import { useDocLink } from '@/context/i18n' import { useProviderContext } from '@/context/provider-context' import useTheme from '@/hooks/use-theme' import { createApp } from '@/service/apps' @@ -346,41 +344,26 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP function AppPreview({ mode }: { mode: AppModeEnum }) { const { t } = useTranslation() - const docLink = useDocLink() const modeToPreviewInfoMap = { [AppModeEnum.CHAT]: { title: t('types.chatbot', { ns: 'app' }), description: t('newApp.chatbotUserDescription', { ns: 'app' }), - link: docLink('/guides/application-orchestrate/chatbot-application'), }, [AppModeEnum.ADVANCED_CHAT]: { title: t('types.advanced', { ns: 'app' }), description: t('newApp.advancedUserDescription', { ns: 'app' }), - link: docLink('/guides/workflow/README', { - 'zh-Hans': '/guides/workflow/readme', - 'ja-JP': '/guides/workflow/concepts', - }), }, 
[AppModeEnum.AGENT_CHAT]: { title: t('types.agent', { ns: 'app' }), description: t('newApp.agentUserDescription', { ns: 'app' }), - link: docLink('/guides/application-orchestrate/agent'), }, [AppModeEnum.COMPLETION]: { title: t('newApp.completeApp', { ns: 'app' }), description: t('newApp.completionUserDescription', { ns: 'app' }), - link: docLink('/guides/application-orchestrate/text-generator', { - 'zh-Hans': '/guides/application-orchestrate/readme', - 'ja-JP': '/guides/application-orchestrate/README', - }), }, [AppModeEnum.WORKFLOW]: { title: t('types.workflow', { ns: 'app' }), description: t('newApp.workflowUserDescription', { ns: 'app' }), - link: docLink('/guides/workflow/README', { - 'zh-Hans': '/guides/workflow/readme', - 'ja-JP': '/guides/workflow/concepts', - }), }, } const previewInfo = modeToPreviewInfoMap[mode] @@ -389,7 +372,6 @@ function AppPreview({ mode }: { mode: AppModeEnum }) {

{previewInfo.title}

{previewInfo.description} - {previewInfo.link && {t('newApp.learnMore', { ns: 'app' })}}
) diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index 133bd34dbc..778a2c1420 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -58,7 +58,7 @@ const Uploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 16f67de547..4fbc2f11d7 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -39,6 +39,7 @@ import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' +import { AppSourceType } from '@/service/share' import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' @@ -647,12 +648,12 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { item.from_source === 'admin')} onFeedback={feedback => onFeedback(detail.message.id, feedback)} diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index c1a662df5d..9975c81b3e 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -245,7 +245,7 @@ function AppCard({
window.open(docLink('/guides/workflow/node/user-input'), '_blank')} + onClick={() => window.open(docLink('/use-dify/nodes/user-input'), '_blank')} > {t('overview.appInfo.enableTooltip.learnMore', { ns: 'appOverview' })}
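Illustrative sketch, not part of the patch: the uploader drop handlers changed above all swap `[...e.dataTransfer.files]` for `Array.from(e.dataTransfer.files)`. The snippet below shows the resulting handler shape under assumed names (`handleDrop`, `onFile` are hypothetical); `DataTransfer.files` is an array-like `FileList`, and `Array.from` copies it into a real `File[]` so length checks and array methods work without depending on spread-over-iterable support.

```tsx
import type { DragEvent } from 'react'

// Hypothetical drop handler mirroring the shape used by the uploaders above.
const handleDrop = (e: DragEvent<HTMLDivElement>, onFile: (file: File) => void) => {
  e.preventDefault()
  e.stopPropagation()
  if (!e.dataTransfer)
    return

  // FileList -> File[]; works for array-likes as well as iterables.
  const files = Array.from(e.dataTransfer.files)
  if (files.length !== 1)
    return // the real handlers show a validation toast here; omitted in this sketch

  onFile(files[0])
}
```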
diff --git a/web/app/components/app/overview/customize/index.spec.tsx b/web/app/components/app/overview/customize/index.spec.tsx index 1cfe5a7130..e1bb7e938d 100644 --- a/web/app/components/app/overview/customize/index.spec.tsx +++ b/web/app/components/app/overview/customize/index.spec.tsx @@ -305,7 +305,7 @@ describe('CustomizeModal', () => { // Assert expect(mockWindowOpen).toHaveBeenCalledTimes(1) expect(mockWindowOpen).toHaveBeenCalledWith( - expect.stringContaining('/guides/application-publishing/developing-with-apis'), + expect.stringContaining('/use-dify/publish/developing-with-apis'), '_blank', ) }) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index 77dae81a01..c7391abe3d 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -118,7 +118,7 @@ const CustomizeModal: FC = ({ className="mt-2" onClick={() => window.open( - docLink('/guides/application-publishing/developing-with-apis'), + docLink('/use-dify/publish/developing-with-apis'), '_blank', )} > diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 428a475da9..0d087e27c2 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -23,7 +23,6 @@ import Textarea from '@/app/components/base/textarea' import { useToastContext } from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' -import { useDocLink } from '@/context/i18n' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' import { languages } from '@/i18n-config/language' @@ -100,7 +99,6 @@ const SettingsModal: FC = ({ const [language, setLanguage] = useState(default_language) const [saveLoading, setSaveLoading] = useState(false) const { t } = useTranslation() - const docLink = useDocLink() const [showAppIconPicker, setShowAppIconPicker] = useState(false) const [appIcon, setAppIcon] = useState( @@ -240,16 +238,6 @@ const SettingsModal: FC = ({
{t(`${prefixSettings}.modalTip`, { ns: 'appOverview' })} - - {t('operation.learnMore', { ns: 'common' })} -
{/* form body */} diff --git a/web/app/components/app/overview/trigger-card.tsx b/web/app/components/app/overview/trigger-card.tsx index a2d28606a1..12a294b4ec 100644 --- a/web/app/components/app/overview/trigger-card.tsx +++ b/web/app/components/app/overview/trigger-card.tsx @@ -208,7 +208,7 @@ function TriggerCard({ appInfo, onToggleResult }: ITriggerCardProps) { {t('overview.triggerInfo.triggerStatusDescription', { ns: 'appOverview' })} {' '} void onSave?: (messageId: string) => void isMobile?: boolean - isInstalledApp: boolean + appSourceType: AppSourceType installedAppId?: string taskId?: string controlClearMoreLikeThis?: number @@ -90,7 +90,7 @@ const GenerationItem: FC = ({ onSave, depth = 1, isMobile, - isInstalledApp, + appSourceType, installedAppId, taskId, controlClearMoreLikeThis, @@ -103,6 +103,7 @@ const GenerationItem: FC = ({ const { t } = useTranslation() const params = useParams() const isTop = depth === 1 + const isTryApp = appSourceType === AppSourceType.tryApp const [completionRes, setCompletionRes] = useState('') const [childMessageId, setChildMessageId] = useState(null) const [childFeedback, setChildFeedback] = useState({ @@ -116,7 +117,7 @@ const GenerationItem: FC = ({ const setShowPromptLogModal = useAppStore(s => s.setShowPromptLogModal) const handleFeedback = async (childFeedback: FeedbackType) => { - await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, isInstalledApp, installedAppId) + await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, appSourceType, installedAppId) setChildFeedback(childFeedback) } @@ -134,7 +135,7 @@ const GenerationItem: FC = ({ onSave, isShowTextToSpeech, isMobile, - isInstalledApp, + appSourceType, installedAppId, controlClearMoreLikeThis, isWorkflow, @@ -148,7 +149,7 @@ const GenerationItem: FC = ({ return } startQuerying() - const res: any = await fetchMoreLikeThis(messageId as string, isInstalledApp, installedAppId) + const res: any = await fetchMoreLikeThis(messageId as string, appSourceType, installedAppId) setCompletionRes(res.answer) setChildFeedback({ rating: null, @@ -336,7 +337,7 @@ const GenerationItem: FC = ({ )} {/* action buttons */}
- {!isInWebApp && !isInstalledApp && !isResponding && ( + {!isInWebApp && (appSourceType !== AppSourceType.installedApp) && !isResponding && (
@@ -345,12 +346,12 @@ const GenerationItem: FC = ({
)}
- {moreLikeThis && ( + {moreLikeThis && !isTryApp && ( )} - {isShowTextToSpeech && ( + {isShowTextToSpeech && !isTryApp && ( = ({ )} - {isInWebApp && !isWorkflow && ( + {isInWebApp && !isWorkflow && !isTryApp && ( { onSave?.(messageId as string) }}> )}
- {(supportFeedback || isInWebApp) && !isWorkflow && !isError && messageId && ( + {(supportFeedback || isInWebApp) && !isWorkflow && !isTryApp && !isError && messageId && (
{!feedback?.rating && ( <> diff --git a/web/app/components/apps/hooks/use-dsl-drag-drop.ts b/web/app/components/apps/hooks/use-dsl-drag-drop.ts index dda5773062..77d89b87da 100644 --- a/web/app/components/apps/hooks/use-dsl-drag-drop.ts +++ b/web/app/components/apps/hooks/use-dsl-drag-drop.ts @@ -36,7 +36,7 @@ export const useDSLDragDrop = ({ onDSLFileDropped, containerRef, enabled = true if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length === 0) return diff --git a/web/app/components/apps/index.spec.tsx b/web/app/components/apps/index.spec.tsx index c3dc39955d..c77c1bdb01 100644 --- a/web/app/components/apps/index.spec.tsx +++ b/web/app/components/apps/index.spec.tsx @@ -1,3 +1,5 @@ +import type { ReactNode } from 'react' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { render, screen } from '@testing-library/react' import * as React from 'react' @@ -22,6 +24,15 @@ vi.mock('@/app/education-apply/hooks', () => ({ }, })) +vi.mock('@/hooks/use-import-dsl', () => ({ + useImportDSL: () => ({ + handleImportDSL: vi.fn(), + handleImportDSLConfirm: vi.fn(), + versions: [], + isFetching: false, + }), +})) + // Mock List component vi.mock('./list', () => ({ default: () => { @@ -30,6 +41,25 @@ vi.mock('./list', () => ({ })) describe('Apps', () => { + const createQueryClient = () => new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }) + + const renderWithClient = (ui: React.ReactElement) => { + const queryClient = createQueryClient() + const wrapper = ({ children }: { children: ReactNode }) => ( + {children} + ) + return { + queryClient, + ...render(ui, { wrapper }), + } + } + beforeEach(() => { vi.clearAllMocks() documentTitleCalls = [] @@ -38,17 +68,17 @@ describe('Apps', () => { describe('Rendering', () => { it('should render without crashing', () => { - render() + renderWithClient() expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) it('should render List component', () => { - render() + renderWithClient() expect(screen.getByText('Apps List')).toBeInTheDocument() }) it('should have correct container structure', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('relative', 'flex', 'h-0', 'shrink-0', 'grow', 'flex-col') }) @@ -56,19 +86,19 @@ describe('Apps', () => { describe('Hooks', () => { it('should call useDocumentTitle with correct title', () => { - render() + renderWithClient() expect(documentTitleCalls).toContain('common.menus.apps') }) it('should call useEducationInit', () => { - render() + renderWithClient() expect(educationInitCalls).toBeGreaterThan(0) }) }) describe('Integration', () => { it('should render full component tree', () => { - render() + renderWithClient() // Verify container exists expect(screen.getByTestId('apps-list')).toBeInTheDocument() @@ -79,23 +109,32 @@ describe('Apps', () => { }) it('should handle multiple renders', () => { - const { rerender } = render() + const queryClient = createQueryClient() + const { rerender } = render( + + + , + ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() - rerender() + rerender( + + + , + ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) }) describe('Styling', () => { it('should have overflow-y-auto class', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement 
expect(wrapper).toHaveClass('overflow-y-auto') }) it('should have background styling', () => { - const { container } = render() + const { container } = renderWithClient() const wrapper = container.firstChild as HTMLElement expect(wrapper).toHaveClass('bg-background-body') }) diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index b151df1e1f..255bfbf9c5 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -1,7 +1,17 @@ 'use client' +import type { CreateAppModalProps } from '../explore/create-app-modal' +import type { CurrentTryAppParams } from '@/context/explore-context' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useEducationInit } from '@/app/education-apply/hooks' +import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' +import { useImportDSL } from '@/hooks/use-import-dsl' +import { DSLImportMode } from '@/models/app' +import { fetchAppDetail } from '@/service/explore' +import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' +import CreateAppModal from '../explore/create-app-modal' +import TryApp from '../explore/try-app' import List from './list' const Apps = () => { @@ -10,10 +20,124 @@ const Apps = () => { useDocumentTitle(t('menus.apps', { ns: 'common' })) useEducationInit() + const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const currApp = currentTryAppParams?.app + const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) + const hideTryAppPanel = useCallback(() => { + setIsShowTryAppPanel(false) + }, []) + const setShowTryAppPanel = (showTryAppPanel: boolean, params?: CurrentTryAppParams) => { + if (showTryAppPanel) + setCurrentTryAppParams(params) + else + setCurrentTryAppParams(undefined) + setIsShowTryAppPanel(showTryAppPanel) + } + const [isShowCreateModal, setIsShowCreateModal] = useState(false) + + const handleShowFromTryApp = useCallback(() => { + setIsShowCreateModal(true) + }, []) + + const [controlRefreshList, setControlRefreshList] = useState(0) + const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0) + const onSuccess = useCallback(() => { + setControlRefreshList(prev => prev + 1) + setControlHideCreateFromTemplatePanel(prev => prev + 1) + }, []) + + const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) + + const { + handleImportDSL, + handleImportDSLConfirm, + versions, + isFetching, + } = useImportDSL() + + const onConfirmDSL = useCallback(async () => { + await handleImportDSLConfirm({ + onSuccess, + }) + }, [handleImportDSLConfirm, onSuccess]) + + const onCreate: CreateAppModalProps['onConfirm'] = async ({ + name, + icon_type, + icon, + icon_background, + description, + }) => { + hideTryAppPanel() + + const { export_data } = await fetchAppDetail( + currApp?.app.id as string, + ) + const payload = { + mode: DSLImportMode.YAML_CONTENT, + yaml_content: export_data, + name, + icon_type, + icon, + icon_background, + description, + } + await handleImportDSL(payload, { + onSuccess: () => { + setIsShowCreateModal(false) + }, + onPending: () => { + setShowDSLConfirmModal(true) + }, + }) + } + return ( -
- -
+ +
+ + {isShowTryAppPanel && ( + + )} + + { + showDSLConfirmModal && ( + setShowDSLConfirmModal(false)} + onConfirm={onConfirmDSL} + confirmDisabled={isFetching} + /> + ) + } + + {isShowCreateModal && ( + setIsShowCreateModal(false)} + /> + )} +
+
) } diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 8a236fe260..6bf79b7338 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -1,5 +1,6 @@ 'use client' +import type { FC } from 'react' import { RiApps2Line, RiDragDropLine, @@ -53,7 +54,12 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro ssr: false, }) -const List = () => { +type Props = { + controlRefreshList?: number +} +const List: FC = ({ + controlRefreshList = 0, +}) => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() @@ -110,6 +116,13 @@ const List = () => { refetch, } = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator }) + useEffect(() => { + if (controlRefreshList > 0) { + refetch() + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [controlRefreshList]) + const anchorRef = useRef(null) const options = [ { value: 'all', text: t('types.all', { ns: 'app' }), icon: }, diff --git a/web/app/components/apps/new-app-card.tsx b/web/app/components/apps/new-app-card.tsx index bfa7af3892..868da0dcb5 100644 --- a/web/app/components/apps/new-app-card.tsx +++ b/web/app/components/apps/new-app-card.tsx @@ -6,10 +6,12 @@ import { useSearchParams, } from 'next/navigation' import * as React from 'react' -import { useMemo, useState } from 'react' +import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal' import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' +import AppListContext from '@/context/app-list-context' import { useProviderContext } from '@/context/provider-context' import { cn } from '@/utils/classnames' @@ -55,6 +57,13 @@ const CreateAppCard = ({ return undefined }, [dslUrl]) + const controlHideCreateFromTemplatePanel = useContextSelector(AppListContext, ctx => ctx.controlHideCreateFromTemplatePanel) + useEffect(() => { + if (controlHideCreateFromTemplatePanel > 0) + // eslint-disable-next-line react-hooks-extra/no-direct-set-state-in-use-effect + setShowNewAppTemplateDialog(false) + }, [controlHideCreateFromTemplatePanel]) + return (
{ +const ActionButton = ({ className, size, state = ActionButtonState.Default, styleCss, children, ref, disabled, ...props }: ActionButtonProps) => { return ( + ) + }, +) +CarouselPrevious.displayName = 'CarouselPrevious' + +const CarouselNext = React.forwardRef( + ({ children, ...props }, ref) => { + const { scrollNext, canScrollNext } = useCarousel() + + return ( + + ) + }, +) +CarouselNext.displayName = 'CarouselNext' + +const CarouselDot = React.forwardRef( + ({ children, ...props }, ref) => { + const { api, selectedIndex } = useCarousel() + + return api?.slideNodes().map((_, index) => { + return ( + + ) + }) + }, +) +CarouselDot.displayName = 'CarouselDot' + +const CarouselPlugins = { + Autoplay, +} + +Carousel.Content = CarouselContent +Carousel.Item = CarouselItem +Carousel.Previous = CarouselPrevious +Carousel.Next = CarouselNext +Carousel.Dot = CarouselDot +Carousel.Plugin = CarouselPlugins + +export { Carousel, useCarousel } diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index ae5b9e35b0..20717eab96 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -13,6 +13,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested import { Markdown } from '@/app/components/base/markdown' import { InputVarType } from '@/app/components/workflow/types' import { + AppSourceType, fetchSuggestedQuestions, getUrl, stopChatMessageResponding, @@ -55,6 +56,11 @@ const ChatWrapper = () => { initUserVariables, } = useChatWithHistoryContext() + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp + + // Semantic variable for better code readability + const isHistoryConversation = !!currentConversationId + const appConfig = useMemo(() => { const config = appParams || {} @@ -82,7 +88,7 @@ const ChatWrapper = () => { inputsForm: inputsForms, }, appPrevChatTree, - taskId => stopChatMessageResponding('', taskId, isInstalledApp, appId), + taskId => stopChatMessageResponding('', taskId, appSourceType, appId), clearChatList, setClearChatList, ) @@ -178,11 +184,11 @@ const ChatWrapper = () => { } handleSend( - getUrl('chat-messages', isInstalledApp, appId || ''), + getUrl('chat-messages', appSourceType, appId || ''), data, { - onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, isInstalledApp, appId), - onConversationComplete: currentConversationId ? undefined : handleNewConversationCompleted, + onGetSuggestedQuestions: responseItemId => fetchSuggestedQuestions(responseItemId, appSourceType, appId), + onConversationComplete: isHistoryConversation ? 
undefined : handleNewConversationCompleted, isPublicAPI: !isInstalledApp, }, ) diff --git a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx index f6a8f25cbb..399f16716d 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.spec.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.spec.tsx @@ -5,6 +5,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { act, renderHook, waitFor } from '@testing-library/react' import { ToastProvider } from '@/app/components/base/toast' import { + AppSourceType, fetchChatList, fetchConversations, generationConversationName, @@ -49,20 +50,24 @@ vi.mock('../utils', async () => { } }) -vi.mock('@/service/share', () => ({ - fetchChatList: vi.fn(), - fetchConversations: vi.fn(), - generationConversationName: vi.fn(), - fetchAppInfo: vi.fn(), - fetchAppMeta: vi.fn(), - fetchAppParams: vi.fn(), - getAppAccessModeByAppCode: vi.fn(), - delConversation: vi.fn(), - pinConversation: vi.fn(), - renameConversation: vi.fn(), - unpinConversation: vi.fn(), - updateFeedback: vi.fn(), -})) +vi.mock('@/service/share', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + fetchChatList: vi.fn(), + fetchConversations: vi.fn(), + generationConversationName: vi.fn(), + fetchAppInfo: vi.fn(), + fetchAppMeta: vi.fn(), + fetchAppParams: vi.fn(), + getAppAccessModeByAppCode: vi.fn(), + delConversation: vi.fn(), + pinConversation: vi.fn(), + renameConversation: vi.fn(), + unpinConversation: vi.fn(), + updateFeedback: vi.fn(), + } +}) const mockFetchConversations = vi.mocked(fetchConversations) const mockFetchChatList = vi.mocked(fetchChatList) @@ -162,13 +167,13 @@ describe('useChatWithHistory', () => { // Assert await waitFor(() => { - expect(mockFetchConversations).toHaveBeenCalledWith(false, 'app-1', undefined, true, 100) + expect(mockFetchConversations).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', undefined, true, 100) }) await waitFor(() => { - expect(mockFetchConversations).toHaveBeenCalledWith(false, 'app-1', undefined, false, 100) + expect(mockFetchConversations).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', undefined, false, 100) }) await waitFor(() => { - expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', false, 'app-1') + expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', AppSourceType.webApp, 'app-1') }) await waitFor(() => { expect(result.current.pinnedConversationList).toEqual(pinnedData.data) @@ -204,7 +209,7 @@ describe('useChatWithHistory', () => { // Assert await waitFor(() => { - expect(mockGenerationConversationName).toHaveBeenCalledWith(false, 'app-1', 'conversation-new') + expect(mockGenerationConversationName).toHaveBeenCalledWith(AppSourceType.webApp, 'app-1', 'conversation-new') }) await waitFor(() => { expect(result.current.conversationList[0]).toEqual(generatedConversation) diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 0ef7aeb5b4..da344a9789 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -29,6 +29,7 @@ import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' import { changeLanguage } from '@/i18n-config/client' import { + AppSourceType, delConversation, pinConversation, renameConversation, @@ -95,6 +96,7 @@ function 
getFormattedChatList(messages: any[]) { export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) + const appSourceType = isInstalledApp ? AppSourceType.installedApp : AppSourceType.webApp const appInfo = useWebAppStore(s => s.appInfo) const appParams = useWebAppStore(s => s.appParams) const appMeta = useWebAppStore(s => s.appMeta) @@ -200,7 +202,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [currentConversationId, newConversationId]) const { data: appPinnedConversationData } = useShareConversations({ - isInstalledApp, + appSourceType, appId, pinned: true, limit: 100, @@ -213,7 +215,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { data: appConversationData, isLoading: appConversationDataLoading, } = useShareConversations({ - isInstalledApp, + appSourceType, appId, pinned: false, limit: 100, @@ -227,7 +229,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { isLoading: appChatListDataLoading, } = useShareChatList({ conversationId: chatShouldReloadKey, - isInstalledApp, + appSourceType, appId, }, { enabled: !!chatShouldReloadKey, @@ -357,10 +359,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const { data: newConversation } = useShareConversationName({ conversationId: newConversationId, - isInstalledApp, + appSourceType, appId, }, { refetchOnWindowFocus: false, + enabled: !!newConversationId, }) const [originConversationList, setOriginConversationList] = useState([]) useEffect(() => { @@ -485,16 +488,16 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [invalidateShareConversations]) const handlePinConversation = useCallback(async (conversationId: string) => { - await pinConversation(isInstalledApp, appId, conversationId) + await pinConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) handleUpdateConversationList() - }, [isInstalledApp, appId, notify, t, handleUpdateConversationList]) + }, [appSourceType, appId, notify, t, handleUpdateConversationList]) const handleUnpinConversation = useCallback(async (conversationId: string) => { - await unpinConversation(isInstalledApp, appId, conversationId) + await unpinConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) handleUpdateConversationList() - }, [isInstalledApp, appId, notify, t, handleUpdateConversationList]) + }, [appSourceType, appId, notify, t, handleUpdateConversationList]) const [conversationDeleting, setConversationDeleting] = useState(false) const handleDeleteConversation = useCallback(async ( @@ -508,7 +511,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { try { setConversationDeleting(true) - await delConversation(isInstalledApp, appId, conversationId) + await delConversation(appSourceType, appId, conversationId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) onSuccess() } @@ -543,7 +546,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { setConversationRenaming(true) try { - await renameConversation(isInstalledApp, appId, conversationId, newName) + await renameConversation(appSourceType, appId, conversationId, newName) notify({ type: 'success', @@ -573,9 +576,9 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { }, [handleConversationIdInfoChange, 
invalidateShareConversations]) const handleFeedback = useCallback(async (messageId: string, feedback: Feedback) => { - await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, isInstalledApp, appId) + await updateFeedback({ url: `/messages/${messageId}/feedbacks`, body: { rating: feedback.rating, content: feedback.content } }, appSourceType, appId) notify({ type: 'success', message: t('api.success', { ns: 'common' }) }) - }, [isInstalledApp, appId, t, notify]) + }, [appSourceType, appId, t, notify]) return { isInstalledApp, diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 11ef672003..0ea46aa930 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -345,7 +345,7 @@ const Answer: FC = ({ data={workflowProcess} item={item} hideProcessDetail={hideProcessDetail} - readonly={hideProcessDetail && appData ? !appData.site.show_workflow_steps : undefined} + readonly={hideProcessDetail && appData ? !appData.site?.show_workflow_steps : undefined} /> ) } diff --git a/web/app/components/base/chat/chat/answer/suggested-questions.tsx b/web/app/components/base/chat/chat/answer/suggested-questions.tsx index 019ed78348..ce997a49b6 100644 --- a/web/app/components/base/chat/chat/answer/suggested-questions.tsx +++ b/web/app/components/base/chat/chat/answer/suggested-questions.tsx @@ -1,6 +1,7 @@ import type { FC } from 'react' import type { ChatItem } from '../../types' import { memo } from 'react' +import { cn } from '@/utils/classnames' import { useChatContext } from '../context' type SuggestedQuestionsProps = { @@ -9,7 +10,7 @@ type SuggestedQuestionsProps = { const SuggestedQuestions: FC = ({ item, }) => { - const { onSend } = useChatContext() + const { onSend, readonly } = useChatContext() const { isOpeningStatement, @@ -24,8 +25,11 @@ const SuggestedQuestions: FC = ({ {suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => (
onSend?.(question)} + className={cn( + 'system-sm-medium mr-1 mt-1 inline-flex max-w-full shrink-0 cursor-pointer flex-wrap rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3.5 py-2 text-components-button-secondary-accent-text shadow-xs last:mr-0 hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover', + readonly && 'pointer-events-none opacity-50', + )} + onClick={() => !readonly && onSend?.(question)} > {question}
diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 192f46fb23..9de52cb18c 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -5,6 +5,7 @@ import type { } from '../../types' import type { InputForm } from '../type' import type { FileUpload } from '@/app/components/base/features/types' +import { noop } from 'es-toolkit/function' import { decode } from 'html-entities' import Recorder from 'js-audio-recorder' import { @@ -30,6 +31,7 @@ import { useTextAreaHeight } from './hooks' import Operation from './operation' type ChatInputAreaProps = { + readonly?: boolean botName?: string showFeatureBar?: boolean showFileUpload?: boolean @@ -45,6 +47,7 @@ type ChatInputAreaProps = { disabled?: boolean } const ChatInputArea = ({ + readonly, botName, showFeatureBar, showFileUpload, @@ -170,6 +173,7 @@ const ChatInputArea = ({ const operation = (
{ @@ -239,7 +244,14 @@ const ChatInputArea = ({ ) }
- {showFeatureBar && } + {showFeatureBar && ( + + )} ) } diff --git a/web/app/components/base/chat/chat/chat-input-area/operation.tsx b/web/app/components/base/chat/chat/chat-input-area/operation.tsx index 27e5bf6cad..5bce827754 100644 --- a/web/app/components/base/chat/chat/chat-input-area/operation.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/operation.tsx @@ -8,6 +8,7 @@ import { RiMicLine, RiSendPlane2Fill, } from '@remixicon/react' +import { noop } from 'es-toolkit/function' import { memo } from 'react' import ActionButton from '@/app/components/base/action-button' import Button from '@/app/components/base/button' @@ -15,6 +16,7 @@ import { FileUploaderInChatInput } from '@/app/components/base/file-uploader' import { cn } from '@/utils/classnames' type OperationProps = { + readonly?: boolean fileConfig?: FileUpload speechToTextConfig?: EnableType onShowVoiceInput?: () => void @@ -23,6 +25,7 @@ type OperationProps = { ref?: Ref } const Operation: FC = ({ + readonly, ref, fileConfig, speechToTextConfig, @@ -41,11 +44,12 @@ const Operation: FC = ({ ref={ref} >
- {fileConfig?.enabled && } + {fileConfig?.enabled && } { speechToTextConfig?.enabled && ( @@ -56,7 +60,7 @@ const Operation: FC = ({ + { + !hideEditEntrance && ( + + ) + }
)}
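Minimal sketch (component and prop names are invented for illustration) of the readonly-threading pattern the chat input changes above follow: the parent decides once that the surface is read-only, and children both disable their controls and swap real callbacks for `noop` so nothing fires.

```tsx
import { noop } from 'es-toolkit/function'

type SendButtonProps = {
  readonly?: boolean
  onSend?: () => void
}

// When readonly, the handler becomes noop so downstream clicks stay inert
// without every child needing its own readonly check.
const SendButton = ({ readonly, onSend }: SendButtonProps) => {
  const handleClick = readonly ? noop : (onSend ?? noop)
  return (
    <button type="button" disabled={readonly} onClick={handleClick}>
      Send
    </button>
  )
}

export default SendButton
```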
diff --git a/web/app/components/base/features/new-feature-panel/index.tsx b/web/app/components/base/features/new-feature-panel/index.tsx index 9e8a2397be..9ee2f407eb 100644 --- a/web/app/components/base/features/new-feature-panel/index.tsx +++ b/web/app/components/base/features/new-feature-panel/index.tsx @@ -2,7 +2,6 @@ import type { OnFeaturesChange } from '@/app/components/base/features/types' import type { InputVar } from '@/app/components/workflow/types' import type { PromptVariable } from '@/models/debug' import { RiCloseLine, RiInformation2Fill } from '@remixicon/react' -import * as React from 'react' import { useTranslation } from 'react-i18next' import AnnotationReply from '@/app/components/base/features/new-feature-panel/annotation-reply' @@ -18,7 +17,6 @@ import SpeechToText from '@/app/components/base/features/new-feature-panel/speec import TextToSpeech from '@/app/components/base/features/new-feature-panel/text-to-speech' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' -import { useDocLink } from '@/context/i18n' type Props = { show: boolean @@ -46,7 +44,6 @@ const NewFeaturePanel = ({ onAutoAddPromptVariable, }: Props) => { const { t } = useTranslation() - const docLink = useDocLink() const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.tts) @@ -76,14 +73,6 @@ const NewFeaturePanel = ({
diff --git a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx index 59b62d0bfd..c9455c98eb 100644 --- a/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx @@ -319,7 +319,7 @@ const ModerationSettingModal: FC = ({
{t('apiBasedExtension.selector.title', { ns: 'common' })}
{ const renderTrigger = useCallback((open: boolean) => { return ( ) }, []) + if (readonly) + return renderTrigger(false) + return ( { return nonceMatch ? nonceMatch[1] : undefined } -const GA: FC = ({ +const GA: FC = async ({ gaType, }) => { if (IS_CE_EDITION) return null const cspHeader = IS_PROD - ? (headers() as unknown as UnsafeUnwrappedHeaders).get('content-security-policy') + ? (await headers()).get('content-security-policy') : null const nonce = extractNonceFromCSP(cspHeader) diff --git a/web/app/components/base/image-uploader/text-generation-image-uploader.tsx b/web/app/components/base/image-uploader/text-generation-image-uploader.tsx index d5ad09ff43..569ff559a2 100644 --- a/web/app/components/base/image-uploader/text-generation-image-uploader.tsx +++ b/web/app/components/base/image-uploader/text-generation-image-uploader.tsx @@ -70,10 +70,12 @@ const PasteImageLinkButton: FC = ({ type TextGenerationImageUploaderProps = { settings: VisionSettings onFilesChange: (files: ImageFile[]) => void + disabled?: boolean } const TextGenerationImageUploader: FC = ({ settings, onFilesChange, + disabled, }) => { const { t } = useTranslation() @@ -93,7 +95,7 @@ const TextGenerationImageUploader: FC = ({ const localUpload = ( = settings.number_limits} + disabled={files.length >= settings.number_limits || disabled} limit={+settings.image_file_size_limit!} > { @@ -115,7 +117,7 @@ const TextGenerationImageUploader: FC = ({ const urlUpload = ( = settings.number_limits} + disabled={files.length >= settings.number_limits || disabled} /> ) diff --git a/web/app/components/base/input-with-copy/index.spec.tsx b/web/app/components/base/input-with-copy/index.spec.tsx index 438e72d142..1a4319603e 100644 --- a/web/app/components/base/input-with-copy/index.spec.tsx +++ b/web/app/components/base/input-with-copy/index.spec.tsx @@ -3,13 +3,8 @@ import * as React from 'react' import { createReactI18nextMock } from '@/test/i18n-mock' import InputWithCopy from './index' -// Create a mock function that we can track using vi.hoisted -const mockCopyToClipboard = vi.hoisted(() => vi.fn(() => true)) - -// Mock the copy-to-clipboard library -vi.mock('copy-to-clipboard', () => ({ - default: mockCopyToClipboard, -})) +// Mock navigator.clipboard for foxact/use-clipboard +const mockWriteText = vi.fn(() => Promise.resolve()) // Mock the i18n hook with custom translations for test assertions vi.mock('react-i18next', () => createReactI18nextMock({ @@ -19,15 +14,16 @@ vi.mock('react-i18next', () => createReactI18nextMock({ 'overview.appInfo.embedded.copied': 'Copied', })) -// Mock es-toolkit/compat debounce -vi.mock('es-toolkit/compat', () => ({ - debounce: (fn: any) => fn, -})) - describe('InputWithCopy component', () => { beforeEach(() => { vi.clearAllMocks() - mockCopyToClipboard.mockClear() + mockWriteText.mockClear() + // Setup navigator.clipboard mock + Object.assign(navigator, { + clipboard: { + writeText: mockWriteText, + }, + }) }) it('renders correctly with default props', () => { @@ -55,7 +51,9 @@ describe('InputWithCopy component', () => { const copyButton = screen.getByRole('button') fireEvent.click(copyButton) - expect(mockCopyToClipboard).toHaveBeenCalledWith('test value') + await waitFor(() => { + expect(mockWriteText).toHaveBeenCalledWith('test value') + }) }) it('copies custom value when copyValue prop is provided', async () => { @@ -65,7 +63,9 @@ describe('InputWithCopy component', () => { const copyButton = screen.getByRole('button') fireEvent.click(copyButton) - 
expect(mockCopyToClipboard).toHaveBeenCalledWith('custom copy value') + await waitFor(() => { + expect(mockWriteText).toHaveBeenCalledWith('custom copy value') + }) }) it('calls onCopy callback when copy button is clicked', async () => { @@ -76,7 +76,9 @@ describe('InputWithCopy component', () => { const copyButton = screen.getByRole('button') fireEvent.click(copyButton) - expect(onCopyMock).toHaveBeenCalledWith('test value') + await waitFor(() => { + expect(onCopyMock).toHaveBeenCalledWith('test value') + }) }) it('shows copied state after successful copy', async () => { @@ -115,17 +117,19 @@ describe('InputWithCopy component', () => { expect(input).toHaveClass('custom-class') }) - it('handles empty value correctly', () => { + it('handles empty value correctly', async () => { const mockOnChange = vi.fn() render() - const input = screen.getByRole('textbox') + const input = screen.getByDisplayValue('') const copyButton = screen.getByRole('button') expect(input).toBeInTheDocument() expect(copyButton).toBeInTheDocument() fireEvent.click(copyButton) - expect(mockCopyToClipboard).toHaveBeenCalledWith('') + await waitFor(() => { + expect(mockWriteText).toHaveBeenCalledWith('') + }) }) it('maintains focus on input after copy', async () => { diff --git a/web/app/components/base/input-with-copy/index.tsx b/web/app/components/base/input-with-copy/index.tsx index 745e89fb2f..7981ba6236 100644 --- a/web/app/components/base/input-with-copy/index.tsx +++ b/web/app/components/base/input-with-copy/index.tsx @@ -1,10 +1,8 @@ 'use client' import type { InputProps } from '../input' import { RiClipboardFill, RiClipboardLine } from '@remixicon/react' -import copy from 'copy-to-clipboard' -import { debounce } from 'es-toolkit/compat' +import { useClipboard } from 'foxact/use-clipboard' import * as React from 'react' -import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' import ActionButton from '../action-button' @@ -30,31 +28,16 @@ const InputWithCopy = React.forwardRef(( ref, ) => { const { t } = useTranslation() - const [isCopied, setIsCopied] = useState(false) // Determine what value to copy const valueToString = typeof value === 'string' ? value : String(value || '') const finalCopyValue = copyValue || valueToString - const onClickCopy = debounce(() => { + const { copied, copy, reset } = useClipboard() + + const handleCopy = () => { copy(finalCopyValue) - setIsCopied(true) onCopy?.(finalCopyValue) - }, 100) - - const onMouseLeave = debounce(() => { - setIsCopied(false) - }, 100) - - useEffect(() => { - if (isCopied) { - const timeout = setTimeout(() => { - setIsCopied(false) - }, 2000) - return () => { - clearTimeout(timeout) - } - } - }, [isCopied]) + } return (
@@ -73,21 +56,21 @@ const InputWithCopy = React.forwardRef(( {showCopyButton && (
- {isCopied + {copied ? ( ) diff --git a/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx b/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx new file mode 100644 index 0000000000..735222011b --- /dev/null +++ b/web/app/components/base/markdown/react-markdown-wrapper.spec.tsx @@ -0,0 +1,109 @@ +import type { PropsWithChildren, ReactNode } from 'react' +import { render, screen } from '@testing-library/react' +import { ReactMarkdownWrapper } from './react-markdown-wrapper' + +vi.mock('@/app/components/base/markdown-blocks', () => ({ + AudioBlock: ({ children }: PropsWithChildren) =>
{children}
, + Img: ({ alt }: { alt?: string }) => {alt}, + Link: ({ children, href }: { children?: ReactNode, href?: string }) =>
{children}, + MarkdownButton: ({ children }: PropsWithChildren) => , + MarkdownForm: ({ children }: PropsWithChildren) =>
{children}
, + Paragraph: ({ children }: PropsWithChildren) =>

{children}

, + PluginImg: ({ alt }: { alt?: string }) => {alt}, + PluginParagraph: ({ children }: PropsWithChildren) =>

{children}

, + ScriptBlock: () => null, + ThinkBlock: ({ children }: PropsWithChildren) =>
{children}
, + VideoBlock: ({ children }: PropsWithChildren) =>
{children}
, +})) + +vi.mock('@/app/components/base/markdown-blocks/code-block', () => ({ + default: ({ children }: PropsWithChildren) => {children}, +})) + +describe('ReactMarkdownWrapper', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Strikethrough rendering', () => { + it('should NOT render single tilde as strikethrough', () => { + // Arrange - single tilde should be rendered as literal text + const content = 'Range: 0.3~8mm' + + // Act + render() + + // Assert - check that ~ is rendered as text, not as strikethrough (del element) + // The content should contain the tilde as literal text + expect(screen.getByText(/0\.3~8mm/)).toBeInTheDocument() + expect(document.querySelector('del')).toBeNull() + }) + + it('should render double tildes as strikethrough', () => { + // Arrange - double tildes should create strikethrough + const content = 'This is ~~strikethrough~~ text' + + // Act + render() + + // Assert - del element should be present for double tildes + const delElement = document.querySelector('del') + expect(delElement).not.toBeNull() + expect(delElement?.textContent).toBe('strikethrough') + }) + + it('should handle mixed content with single and double tildes correctly', () => { + // Arrange - real-world example from issue #31391 + const content = 'PCB thickness: 0.3~8mm and ~~removed feature~~ text' + + // Act + render() + + // Assert + // Only double tildes should create strikethrough + const delElements = document.querySelectorAll('del') + expect(delElements).toHaveLength(1) + expect(delElements[0].textContent).toBe('removed feature') + + // Single tilde should remain as literal text + expect(screen.getByText(/0\.3~8mm/)).toBeInTheDocument() + }) + }) + + describe('Basic rendering', () => { + it('should render plain text content', () => { + // Arrange + const content = 'Hello World' + + // Act + render() + + // Assert + expect(screen.getByText('Hello World')).toBeInTheDocument() + }) + + it('should render bold text', () => { + // Arrange + const content = '**bold text**' + + // Act + render() + + // Assert + expect(screen.getByText('bold text')).toBeInTheDocument() + expect(document.querySelector('strong')).not.toBeNull() + }) + + it('should render italic text', () => { + // Arrange + const content = '*italic text*' + + // Act + render() + + // Assert + expect(screen.getByText('italic text')).toBeInTheDocument() + expect(document.querySelector('em')).not.toBeNull() + }) + }) +}) diff --git a/web/app/components/base/markdown/react-markdown-wrapper.tsx b/web/app/components/base/markdown/react-markdown-wrapper.tsx index 291e6c1980..a3693a561a 100644 --- a/web/app/components/base/markdown/react-markdown-wrapper.tsx +++ b/web/app/components/base/markdown/react-markdown-wrapper.tsx @@ -31,7 +31,7 @@ export const ReactMarkdownWrapper: FC = (props) => { return ( void } @@ -23,6 +25,8 @@ const TabHeader: FC = ({ items, value, itemClassName, + itemWrapClassName, + activeItemClassName, onChange, }) => { const renderItem = ({ id, name, icon, extra, disabled }: Item) => ( @@ -30,8 +34,9 @@ const TabHeader: FC = ({ key={id} className={cn( 'system-md-semibold relative flex cursor-pointer items-center border-b-2 border-transparent pb-2 pt-2.5', - id === value ? 'border-components-tab-active text-text-primary' : 'text-text-tertiary', + id === value ? 
cn('border-components-tab-active text-text-primary', activeItemClassName) : 'text-text-tertiary', disabled && 'cursor-not-allowed opacity-30', + itemWrapClassName, )} onClick={() => !disabled && onChange(id)} > diff --git a/web/app/components/base/voice-input/index.tsx b/web/app/components/base/voice-input/index.tsx index 4fa2c774f4..52e3c754f8 100644 --- a/web/app/components/base/voice-input/index.tsx +++ b/web/app/components/base/voice-input/index.tsx @@ -8,7 +8,7 @@ import { useParams, usePathname } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { StopCircle } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import { audioToText } from '@/service/share' +import { AppSourceType, audioToText } from '@/service/share' import { cn } from '@/utils/classnames' import s from './index.module.css' import { convertToMp3 } from './utils' @@ -108,7 +108,7 @@ const VoiceInput = ({ } try { - const audioResponse = await audioToText(url, isPublic, formData) + const audioResponse = await audioToText(url, isPublic ? AppSourceType.webApp : AppSourceType.installedApp, formData) onConverted(audioResponse.text) onCancel() } diff --git a/web/app/components/billing/progress-bar/index.spec.tsx b/web/app/components/billing/progress-bar/index.spec.tsx index a9c91468de..4eb66dcf79 100644 --- a/web/app/components/billing/progress-bar/index.spec.tsx +++ b/web/app/components/billing/progress-bar/index.spec.tsx @@ -2,24 +2,61 @@ import { render, screen } from '@testing-library/react' import ProgressBar from './index' describe('ProgressBar', () => { - it('renders with provided percent and color', () => { - render() + describe('Normal Mode (determinate)', () => { + it('renders with provided percent and color', () => { + render() - const bar = screen.getByTestId('billing-progress-bar') - expect(bar).toHaveClass('bg-test-color') - expect(bar.getAttribute('style')).toContain('width: 42%') + const bar = screen.getByTestId('billing-progress-bar') + expect(bar).toHaveClass('bg-test-color') + expect(bar.getAttribute('style')).toContain('width: 42%') + }) + + it('caps width at 100% when percent exceeds max', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar') + expect(bar.getAttribute('style')).toContain('width: 100%') + }) + + it('uses the default color when no color prop is provided', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar') + expect(bar).toHaveClass('bg-components-progress-bar-progress-solid') + expect(bar.getAttribute('style')).toContain('width: 20%') + }) }) - it('caps width at 100% when percent exceeds max', () => { - render() + describe('Indeterminate Mode', () => { + it('should render indeterminate progress bar when indeterminate is true', () => { + render() - const bar = screen.getByTestId('billing-progress-bar') - expect(bar.getAttribute('style')).toContain('width: 100%') - }) + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toBeInTheDocument() + expect(bar).toHaveClass('bg-progress-bar-indeterminate-stripe') + }) - it('uses the default color when no color prop is provided', () => { - render() + it('should not render normal progress bar when indeterminate is true', () => { + render() - expect(screen.getByTestId('billing-progress-bar')).toHaveClass('#2970FF') + expect(screen.queryByTestId('billing-progress-bar')).not.toBeInTheDocument() + 
expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + }) + + it('should render with default width (w-[30px]) when indeterminateFull is false', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-[30px]') + expect(bar).not.toHaveClass('w-full') + }) + + it('should render with full width (w-full) when indeterminateFull is true', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-full') + expect(bar).not.toHaveClass('w-[30px]') + }) }) }) diff --git a/web/app/components/billing/progress-bar/index.tsx b/web/app/components/billing/progress-bar/index.tsx index c41fc53310..f16bd952ea 100644 --- a/web/app/components/billing/progress-bar/index.tsx +++ b/web/app/components/billing/progress-bar/index.tsx @@ -3,12 +3,27 @@ import { cn } from '@/utils/classnames' type ProgressBarProps = { percent: number color: string + indeterminate?: boolean + indeterminateFull?: boolean // For Sandbox users: full width stripe } const ProgressBar = ({ percent = 0, - color = '#2970FF', + color = 'bg-components-progress-bar-progress-solid', + indeterminate = false, + indeterminateFull = false, }: ProgressBarProps) => { + if (indeterminate) { + return ( +
+
+
+ ) + } + return (
describe('UsageInfo', () => { - it('renders the metric with a suffix unit and tooltip text', () => { - render( - , - ) + describe('Default Mode (non-storage)', () => { + it('renders the metric with a suffix unit and tooltip text', () => { + render( + , + ) - expect(screen.getByTestId('usage-icon')).toBeInTheDocument() - expect(screen.getByText('Apps')).toBeInTheDocument() - expect(screen.getByText('30')).toBeInTheDocument() - expect(screen.getByText('100')).toBeInTheDocument() - expect(screen.getByText('GB')).toBeInTheDocument() + expect(screen.getByTestId('usage-icon')).toBeInTheDocument() + expect(screen.getByText('Apps')).toBeInTheDocument() + expect(screen.getByText('30')).toBeInTheDocument() + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('GB')).toBeInTheDocument() + }) + + it('renders inline unit when unitPosition is inline', () => { + render( + , + ) + + expect(screen.getByText('100GB')).toBeInTheDocument() + }) + + it('shows reset hint text instead of the unit when resetHint is provided', () => { + const resetHint = 'Resets in 3 days' + render( + , + ) + + expect(screen.getByText(resetHint)).toBeInTheDocument() + expect(screen.queryByText('GB')).not.toBeInTheDocument() + }) + + it('displays unlimited text when total is infinite', () => { + render( + , + ) + + expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument() + }) + + it('applies warning color when usage is close to the limit', () => { + render( + , + ) + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-warning-progress') + }) + + it('applies error color when usage exceeds the limit', () => { + render( + , + ) + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-error-progress') + }) + + it('does not render the icon when hideIcon is true', () => { + render( + , + ) + + expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument() + }) }) - it('renders inline unit when unitPosition is inline', () => { - render( - , - ) + describe('Storage Mode', () => { + describe('Below Threshold', () => { + it('should render indeterminate progress bar when usage is below threshold', () => { + render( + , + ) - expect(screen.getByText('100GB')).toBeInTheDocument() - }) + expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + expect(screen.queryByTestId('billing-progress-bar')).not.toBeInTheDocument() + }) - it('shows reset hint text instead of the unit when resetHint is provided', () => { - const resetHint = 'Resets in 3 days' - render( - , - ) + it('should display "< threshold" format when usage is below threshold (non-sandbox)', () => { + render( + , + ) - expect(screen.getByText(resetHint)).toBeInTheDocument() - expect(screen.queryByText('GB')).not.toBeInTheDocument() - }) + // Text "< 50" is rendered inside a single span + expect(screen.getByText(/< 50/)).toBeInTheDocument() + expect(screen.getByText('5120MB')).toBeInTheDocument() + }) - it('displays unlimited text when total is infinite', () => { - render( - , - ) + it('should display "< threshold unit" format when usage is below threshold (sandbox)', () => { + render( + , + ) - expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument() - }) + // Text "< 50" is rendered inside a single span + expect(screen.getByText(/< 50/)).toBeInTheDocument() + // Unit "MB" appears in the display + 
expect(screen.getAllByText('MB').length).toBeGreaterThanOrEqual(1) + }) - it('applies warning color when usage is close to the limit', () => { - render( - , - ) + it('should render full-width indeterminate bar for sandbox users below threshold', () => { + render( + , + ) - const progressBar = screen.getByTestId('billing-progress-bar') - expect(progressBar).toHaveClass('bg-components-progress-warning-progress') - }) + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-full') + }) - it('applies error color when usage exceeds the limit', () => { - render( - , - ) + it('should render narrow indeterminate bar for non-sandbox users below threshold', () => { + render( + , + ) - const progressBar = screen.getByTestId('billing-progress-bar') - expect(progressBar).toHaveClass('bg-components-progress-error-progress') - }) + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-[30px]') + }) + }) - it('does not render the icon when hideIcon is true', () => { - render( - , - ) + describe('Sandbox Full Capacity', () => { + it('should render error color progress bar when sandbox usage >= threshold', () => { + render( + , + ) - expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument() + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-error-progress') + }) + + it('should display "threshold / threshold unit" format when sandbox is at full capacity', () => { + render( + , + ) + + // First span: "50", Third span: "50 MB" + expect(screen.getByText('50')).toBeInTheDocument() + expect(screen.getByText(/50 MB/)).toBeInTheDocument() + expect(screen.getByText('/')).toBeInTheDocument() + }) + }) + + describe('Pro/Team Users Above Threshold', () => { + it('should render normal progress bar when usage >= threshold', () => { + render( + , + ) + + expect(screen.getByTestId('billing-progress-bar')).toBeInTheDocument() + expect(screen.queryByTestId('billing-progress-bar-indeterminate')).not.toBeInTheDocument() + }) + + it('should display actual usage when usage >= threshold', () => { + render( + , + ) + + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('5120MB')).toBeInTheDocument() + }) + }) + + describe('Storage Tooltip', () => { + it('should render tooltip wrapper when storageTooltip is provided', () => { + const { container } = render( + , + ) + + // Tooltip wrapper should contain cursor-default class + const tooltipWrapper = container.querySelector('.cursor-default') + expect(tooltipWrapper).toBeInTheDocument() + }) + }) }) }) diff --git a/web/app/components/billing/usage-info/index.tsx b/web/app/components/billing/usage-info/index.tsx index 8f0c1bcbcc..f820b85eab 100644 --- a/web/app/components/billing/usage-info/index.tsx +++ b/web/app/components/billing/usage-info/index.tsx @@ -1,5 +1,5 @@ 'use client' -import type { FC } from 'react' +import type { ComponentType, FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' import Tooltip from '@/app/components/base/tooltip' @@ -9,7 +9,7 @@ import ProgressBar from '../progress-bar' type Props = { className?: string - Icon: any + Icon: ComponentType<{ className?: string }> name: string tooltip?: string usage: number @@ -19,6 +19,11 @@ type Props = { resetHint?: string resetInDays?: number hideIcon?: boolean + // Props for the 50MB threshold display logic + storageMode?: boolean + storageThreshold?: number + storageTooltip?: string 
+ isSandboxPlan?: boolean } const WARNING_THRESHOLD = 80 @@ -35,30 +40,141 @@ const UsageInfo: FC = ({ resetHint, resetInDays, hideIcon = false, + storageMode = false, + storageThreshold = 50, + storageTooltip, + isSandboxPlan = false, }) => { const { t } = useTranslation() + // Special display logic for usage below threshold (only in storage mode) + const isBelowThreshold = storageMode && usage < storageThreshold + // Sandbox at full capacity (usage >= threshold and it's sandbox plan) + const isSandboxFull = storageMode && isSandboxPlan && usage >= storageThreshold + const percent = usage / total * 100 - const color = percent >= 100 - ? 'bg-components-progress-error-progress' - : (percent >= WARNING_THRESHOLD ? 'bg-components-progress-warning-progress' : 'bg-components-progress-bar-progress-solid') + const getProgressColor = () => { + if (percent >= 100) + return 'bg-components-progress-error-progress' + if (percent >= WARNING_THRESHOLD) + return 'bg-components-progress-warning-progress' + return 'bg-components-progress-bar-progress-solid' + } + const color = getProgressColor() const isUnlimited = total === NUM_INFINITE let totalDisplay: string | number = isUnlimited ? t('plansCommon.unlimited', { ns: 'billing' }) : total if (!isUnlimited && unit && unitPosition === 'inline') totalDisplay = `${total}${unit}` const showUnit = !!unit && !isUnlimited && unitPosition === 'suffix' const resetText = resetHint ?? (typeof resetInDays === 'number' ? t('usagePage.resetsIn', { ns: 'billing', count: resetInDays }) : undefined) - const rightInfo = resetText - ? ( + + const renderRightInfo = () => { + if (resetText) { + return (
{resetText}
) - : (showUnit && ( + } + if (showUnit) { + return (
{unit}
- )) + ) + } + return null + } + + // Render usage display + const renderUsageDisplay = () => { + // Storage mode: special display logic + if (storageMode) { + // Sandbox user at full capacity + if (isSandboxFull) { + return ( +
+ + {storageThreshold} + + / + + {storageThreshold} + {' '} + {unit} + +
+ ) + } + // Usage below threshold - show "< 50 MB" or "< 50 / 5GB" + if (isBelowThreshold) { + return ( +
+ + < + {' '} + {storageThreshold} + + {!isSandboxPlan && ( + <> + / + {totalDisplay} + + )} + {isSandboxPlan && {unit}} +
+ ) + } + // Pro/Team users with usage >= threshold - show actual usage + return ( +
+ {usage} + / + {totalDisplay} +
+ ) + } + + // Default display (storageMode = false) + return ( +
+ {usage} + / + {totalDisplay} +
+ ) + } + + const renderWithTooltip = (children: React.ReactNode) => { + if (storageMode && storageTooltip) { + return ( + {storageTooltip}
} + asChild={false} + > +
{children}
+ + ) + } + return children + } + + // Render progress bar with optional tooltip wrapper + const renderProgressBar = () => { + const progressBar = ( + + ) + return renderWithTooltip(progressBar) + } + + const renderUsageWithTooltip = () => { + return renderWithTooltip(renderUsageDisplay()) + } return (
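The JSX bodies of the render helpers above were stripped from this view. Going only by the branch comments and the strings asserted in the spec ("< 50", "50 / 50 MB", usage "/" total), the display logic reduces to roughly the following plain-string sketch; the real renderUsageDisplay emits the same pieces as styled spans:

type UsageDisplayInput = {
  storageMode: boolean
  isSandboxPlan: boolean
  usage: number
  totalDisplay: string | number
  unit: string
  storageThreshold: number
}

// Hypothetical text-only equivalent of renderUsageDisplay, for illustration.
const formatUsageDisplay = ({
  storageMode,
  isSandboxPlan,
  usage,
  totalDisplay,
  unit,
  storageThreshold,
}: UsageDisplayInput): string => {
  if (storageMode) {
    if (isSandboxPlan && usage >= storageThreshold)
      return `${storageThreshold} / ${storageThreshold} ${unit}` // sandbox at capacity: "50 / 50 MB"
    if (usage < storageThreshold) {
      return isSandboxPlan
        ? `< ${storageThreshold} ${unit}` // "< 50 MB"
        : `< ${storageThreshold} / ${totalDisplay}` // "< 50 / 5120MB"
    }
  }
  return `${usage} / ${totalDisplay}` // default mode, and Pro/Team at or above the threshold
}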
@@ -78,17 +194,10 @@ const UsageInfo: FC = ({ )}
-
- {usage} -
/
-
{totalDisplay}
-
- {rightInfo} + {renderUsageWithTooltip()} + {renderRightInfo()}
- + {renderProgressBar()}
) } diff --git a/web/app/components/billing/usage-info/vector-space-info.spec.tsx b/web/app/components/billing/usage-info/vector-space-info.spec.tsx new file mode 100644 index 0000000000..a811cc9a09 --- /dev/null +++ b/web/app/components/billing/usage-info/vector-space-info.spec.tsx @@ -0,0 +1,305 @@ +import { render, screen } from '@testing-library/react' +import { defaultPlan } from '../config' +import { Plan } from '../type' +import VectorSpaceInfo from './vector-space-info' + +// Mock provider context with configurable plan +let mockPlanType = Plan.sandbox +let mockVectorSpaceUsage = 30 +let mockVectorSpaceTotal = 5120 + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ + plan: { + ...defaultPlan, + type: mockPlanType, + usage: { + ...defaultPlan.usage, + vectorSpace: mockVectorSpaceUsage, + }, + total: { + ...defaultPlan.total, + vectorSpace: mockVectorSpaceTotal, + }, + }, + }), +})) + +describe('VectorSpaceInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + // Reset to default values + mockPlanType = Plan.sandbox + mockVectorSpaceUsage = 30 + mockVectorSpaceTotal = 5120 + }) + + describe('Rendering', () => { + it('should render vector space info component', () => { + render() + + expect(screen.getByText('billing.usagePage.vectorSpace')).toBeInTheDocument() + }) + + it('should apply custom className', () => { + render() + + const container = screen.getByText('billing.usagePage.vectorSpace').closest('.custom-class') + expect(container).toBeInTheDocument() + }) + }) + + describe('Sandbox Plan', () => { + beforeEach(() => { + mockPlanType = Plan.sandbox + mockVectorSpaceUsage = 30 + }) + + it('should render indeterminate progress bar when usage is below threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + }) + + it('should render full-width indeterminate bar for sandbox users', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-full') + }) + + it('should display "< 50" format for sandbox below threshold', () => { + render() + + expect(screen.getByText(/< 50/)).toBeInTheDocument() + }) + }) + + describe('Sandbox Plan at Full Capacity', () => { + beforeEach(() => { + mockPlanType = Plan.sandbox + mockVectorSpaceUsage = 50 + }) + + it('should render error color progress bar when at full capacity', () => { + render() + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-error-progress') + }) + + it('should display "50 / 50 MB" format when at full capacity', () => { + render() + + expect(screen.getByText('50')).toBeInTheDocument() + expect(screen.getByText(/50 MB/)).toBeInTheDocument() + }) + }) + + describe('Professional Plan', () => { + beforeEach(() => { + mockPlanType = Plan.professional + mockVectorSpaceUsage = 30 + }) + + it('should render indeterminate progress bar when usage is below threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + }) + + it('should render narrow indeterminate bar (not full width)', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-[30px]') + expect(bar).not.toHaveClass('w-full') + }) + + it('should display "< 50 / total" format when below threshold', () => { + render() + + expect(screen.getByText(/< 50/)).toBeInTheDocument() + // 5 GB = 5120 MB + 
expect(screen.getByText('5120MB')).toBeInTheDocument() + }) + }) + + describe('Professional Plan Above Threshold', () => { + beforeEach(() => { + mockPlanType = Plan.professional + mockVectorSpaceUsage = 100 + }) + + it('should render normal progress bar when usage >= threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar')).toBeInTheDocument() + expect(screen.queryByTestId('billing-progress-bar-indeterminate')).not.toBeInTheDocument() + }) + + it('should display actual usage when above threshold', () => { + render() + + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('5120MB')).toBeInTheDocument() + }) + }) + + describe('Team Plan', () => { + beforeEach(() => { + mockPlanType = Plan.team + mockVectorSpaceUsage = 30 + }) + + it('should render indeterminate progress bar when usage is below threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + }) + + it('should render narrow indeterminate bar (not full width)', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-[30px]') + expect(bar).not.toHaveClass('w-full') + }) + + it('should display "< 50 / total" format when below threshold', () => { + render() + + expect(screen.getByText(/< 50/)).toBeInTheDocument() + // 20 GB = 20480 MB + expect(screen.getByText('20480MB')).toBeInTheDocument() + }) + }) + + describe('Team Plan Above Threshold', () => { + beforeEach(() => { + mockPlanType = Plan.team + mockVectorSpaceUsage = 100 + }) + + it('should render normal progress bar when usage >= threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar')).toBeInTheDocument() + expect(screen.queryByTestId('billing-progress-bar-indeterminate')).not.toBeInTheDocument() + }) + + it('should display actual usage when above threshold', () => { + render() + + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('20480MB')).toBeInTheDocument() + }) + }) + + describe('Pro/Team Plan Warning State', () => { + it('should show warning color when Professional plan usage approaches limit (80%+)', () => { + mockPlanType = Plan.professional + // 5120 MB * 80% = 4096 MB + mockVectorSpaceUsage = 4100 + + render() + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-warning-progress') + }) + + it('should show warning color when Team plan usage approaches limit (80%+)', () => { + mockPlanType = Plan.team + // 20480 MB * 80% = 16384 MB + mockVectorSpaceUsage = 16500 + + render() + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-warning-progress') + }) + }) + + describe('Pro/Team Plan Error State', () => { + it('should show error color when Professional plan usage exceeds limit', () => { + mockPlanType = Plan.professional + // Exceeds 5120 MB + mockVectorSpaceUsage = 5200 + + render() + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-error-progress') + }) + + it('should show error color when Team plan usage exceeds limit', () => { + mockPlanType = Plan.team + // Exceeds 20480 MB + mockVectorSpaceUsage = 21000 + + render() + + const progressBar = screen.getByTestId('billing-progress-bar') + expect(progressBar).toHaveClass('bg-components-progress-error-progress') + }) + }) + + describe('Enterprise Plan (default case)', () => 
{ + beforeEach(() => { + mockPlanType = Plan.enterprise + mockVectorSpaceUsage = 30 + // Enterprise plan uses total.vectorSpace from context + mockVectorSpaceTotal = 102400 // 100 GB = 102400 MB + }) + + it('should use total.vectorSpace from context for enterprise plan', () => { + render() + + // Enterprise plan should use the mockVectorSpaceTotal value (102400MB) + expect(screen.getByText('102400MB')).toBeInTheDocument() + }) + + it('should render indeterminate progress bar when usage is below threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar-indeterminate')).toBeInTheDocument() + }) + + it('should render narrow indeterminate bar (not full width) for enterprise', () => { + render() + + const bar = screen.getByTestId('billing-progress-bar-indeterminate') + expect(bar).toHaveClass('w-[30px]') + expect(bar).not.toHaveClass('w-full') + }) + + it('should display "< 50 / total" format when below threshold', () => { + render() + + expect(screen.getByText(/< 50/)).toBeInTheDocument() + expect(screen.getByText('102400MB')).toBeInTheDocument() + }) + }) + + describe('Enterprise Plan Above Threshold', () => { + beforeEach(() => { + mockPlanType = Plan.enterprise + mockVectorSpaceUsage = 100 + mockVectorSpaceTotal = 102400 // 100 GB + }) + + it('should render normal progress bar when usage >= threshold', () => { + render() + + expect(screen.getByTestId('billing-progress-bar')).toBeInTheDocument() + expect(screen.queryByTestId('billing-progress-bar-indeterminate')).not.toBeInTheDocument() + }) + + it('should display actual usage when above threshold', () => { + render() + + expect(screen.getByText('100')).toBeInTheDocument() + expect(screen.getByText('102400MB')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/billing/usage-info/vector-space-info.tsx b/web/app/components/billing/usage-info/vector-space-info.tsx index 11e3a6a1ae..e384ef4d9a 100644 --- a/web/app/components/billing/usage-info/vector-space-info.tsx +++ b/web/app/components/billing/usage-info/vector-space-info.tsx @@ -1,26 +1,44 @@ 'use client' import type { FC } from 'react' +import type { BasicPlan } from '../type' import { RiHardDrive3Line, } from '@remixicon/react' import * as React from 'react' import { useTranslation } from 'react-i18next' import { useProviderContext } from '@/context/provider-context' +import { Plan } from '../type' import UsageInfo from '../usage-info' +import { getPlanVectorSpaceLimitMB } from '../utils' type Props = { className?: string } +// Storage threshold in MB - usage below this shows as "< 50 MB" +const STORAGE_THRESHOLD_MB = getPlanVectorSpaceLimitMB(Plan.sandbox) + const VectorSpaceInfo: FC = ({ className, }) => { const { t } = useTranslation() const { plan } = useProviderContext() const { + type, usage, total, } = plan + + // Determine total based on plan type (in MB), derived from ALL_PLANS config + const getTotalInMB = () => { + const planLimit = getPlanVectorSpaceLimitMB(type as BasicPlan) + // For known plans, use the config value; otherwise fall back to API response + return planLimit > 0 ? 
planLimit : total.vectorSpace + } + + const totalInMB = getTotalInMB() + const isSandbox = type === Plan.sandbox + return ( = ({ name={t('usagePage.vectorSpace', { ns: 'billing' })} tooltip={t('usagePage.vectorSpaceTooltip', { ns: 'billing' }) as string} usage={usage.vectorSpace} - total={total.vectorSpace} + total={totalInMB} unit="MB" unitPosition="inline" + storageMode + storageThreshold={STORAGE_THRESHOLD_MB} + storageTooltip={t('usagePage.storageThresholdTooltip', { ns: 'billing' }) as string} + isSandboxPlan={isSandbox} /> ) } diff --git a/web/app/components/billing/utils/index.ts b/web/app/components/billing/utils/index.ts index e7192ec351..39fc0cd7b5 100644 --- a/web/app/components/billing/utils/index.ts +++ b/web/app/components/billing/utils/index.ts @@ -1,7 +1,33 @@ -import type { BillingQuota, CurrentPlanInfoBackend } from '../type' +import type { BasicPlan, BillingQuota, CurrentPlanInfoBackend } from '../type' import dayjs from 'dayjs' import { ALL_PLANS, NUM_INFINITE } from '@/app/components/billing/config' +/** + * Parse vectorSpace string from ALL_PLANS config and convert to MB + * @example "50MB" -> 50, "5GB" -> 5120, "20GB" -> 20480 + */ +export const parseVectorSpaceToMB = (vectorSpace: string): number => { + const match = vectorSpace.match(/^(\d+)(MB|GB)$/i) + if (!match) + return 0 + + const value = Number.parseInt(match[1], 10) + const unit = match[2].toUpperCase() + + return unit === 'GB' ? value * 1024 : value +} + +/** + * Get the vector space limit in MB for a given plan type from ALL_PLANS config + */ +export const getPlanVectorSpaceLimitMB = (planType: BasicPlan): number => { + const planInfo = ALL_PLANS[planType] + if (!planInfo) + return 0 + + return parseVectorSpaceToMB(planInfo.vectorSpace) +} + const parseLimit = (limit: number) => { if (limit === 0) return NUM_INFINITE diff --git a/web/app/components/billing/vector-space-full/index.spec.tsx b/web/app/components/billing/vector-space-full/index.spec.tsx index 0382ec0872..375ac54c22 100644 --- a/web/app/components/billing/vector-space-full/index.spec.tsx +++ b/web/app/components/billing/vector-space-full/index.spec.tsx @@ -21,6 +21,18 @@ vi.mock('../upgrade-btn', () => ({ default: () => , })) +// Mock utils to control threshold and plan limits +vi.mock('../utils', () => ({ + getPlanVectorSpaceLimitMB: (planType: string) => { + // Return 5 for sandbox (threshold) and 100 for team + if (planType === 'sandbox') + return 5 + if (planType === 'team') + return 100 + return 0 + }, +})) + describe('VectorSpaceFull', () => { const planMock = { type: 'team', @@ -52,6 +64,6 @@ describe('VectorSpaceFull', () => { render() expect(screen.getByText('8')).toBeInTheDocument() - expect(screen.getByText('10MB')).toBeInTheDocument() + expect(screen.getByText('100MB')).toBeInTheDocument() }) }) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx index 2f5130ecce..3fa940c60d 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/uploader.tsx @@ -54,7 +54,7 @@ const Uploader: FC = ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 
'datasetCreation' }) }) return diff --git a/web/app/components/datasets/create/file-uploader/index.tsx b/web/app/components/datasets/create/file-uploader/index.tsx index e9c6693e52..781b97200a 100644 --- a/web/app/components/datasets/create/file-uploader/index.tsx +++ b/web/app/components/datasets/create/file-uploader/index.tsx @@ -278,7 +278,7 @@ const FileUploader = ({ onFileListUpdate?.([...fileListRef.current]) } const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = [...(e.target.files ?? [])] as File[] + let files = Array.from(e.target.files ?? []) as File[] files = files.slice(0, fileUploadConfig.batch_count_limit) initialUpload(files.filter(isValid)) }, [isValid, initialUpload, fileUploadConfig]) diff --git a/web/app/components/datasets/create/step-three/index.spec.tsx b/web/app/components/datasets/create/step-three/index.spec.tsx index 43b4916778..74c5912a1b 100644 --- a/web/app/components/datasets/create/step-three/index.spec.tsx +++ b/web/app/components/datasets/create/step-three/index.spec.tsx @@ -190,7 +190,7 @@ describe('StepThree', () => { // Assert const link = screen.getByText('datasetPipeline.addDocuments.stepThree.learnMore') - expect(link).toHaveAttribute('href', 'https://docs.dify.ai/en-US/guides/knowledge-base/integrate-knowledge-within-application') + expect(link).toHaveAttribute('href', 'https://docs.dify.ai/en-US/use-dify/knowledge/integrate-knowledge-within-application') expect(link).toHaveAttribute('target', '_blank') expect(link).toHaveAttribute('rel', 'noreferrer noopener') }) diff --git a/web/app/components/datasets/create/step-three/index.tsx b/web/app/components/datasets/create/step-three/index.tsx index ad26711311..5ab21f6302 100644 --- a/web/app/components/datasets/create/step-three/index.tsx +++ b/web/app/components/datasets/create/step-three/index.tsx @@ -87,7 +87,7 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrie
{t('stepThree.sideTipTitle', { ns: 'datasetCreation' })}
{t('stepThree.sideTipContent', { ns: 'datasetCreation' })}
= ({ {t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })} diff --git a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx index e694537895..4bb8267cea 100644 --- a/web/app/components/datasets/create/website/watercrawl/index.spec.tsx +++ b/web/app/components/datasets/create/website/watercrawl/index.spec.tsx @@ -24,6 +24,11 @@ vi.mock('@/context/modal-context', () => ({ }), })) +// Mock i18n context +vi.mock('@/context/i18n', () => ({ + useDocLink: () => (path?: string) => path ? `https://docs.dify.ai/en${path}` : 'https://docs.dify.ai/en/', +})) + // ============================================================================ // Test Data Factories // ============================================================================ diff --git a/web/app/components/datasets/documents/components/documents-header.tsx b/web/app/components/datasets/documents/components/documents-header.tsx index ed97742fdd..490893d43f 100644 --- a/web/app/components/datasets/documents/components/documents-header.tsx +++ b/web/app/components/datasets/documents/components/documents-header.tsx @@ -121,7 +121,7 @@ const DocumentsHeader: FC = ({ className="flex items-center text-text-accent" target="_blank" rel="noopener noreferrer" - href={docLink('/guides/knowledge-base/integrate-knowledge-within-application')} + href={docLink('/use-dify/knowledge/integrate-knowledge-within-application')} > {t('list.learnMore', { ns: 'datasetDocuments' })} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx index a5c03b671a..d02d5927f2 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/index.tsx @@ -230,7 +230,7 @@ const LocalFile = ({ if (!e.dataTransfer) return - let files = [...e.dataTransfer.files] as File[] + let files = Array.from(e.dataTransfer.files) as File[] if (!supportBatchUpload) files = files.slice(0, 1) @@ -251,7 +251,7 @@ const LocalFile = ({ updateFileList([...fileListRef.current]) } const fileChangeHandle = useCallback((e: React.ChangeEvent) => { - let files = [...(e.target.files ?? [])] as File[] + let files = Array.from(e.target.files ?? []) as File[] files = files.slice(0, fileUploadConfig.batch_count_limit) initialUpload(files.filter(isValid)) }, [isValid, initialUpload, fileUploadConfig.batch_count_limit]) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx index 9b0df231bd..4bdaac895b 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx @@ -138,7 +138,7 @@ const OnlineDocuments = ({
{ render() // Assert - expect(mockDocLink).toHaveBeenCalledWith('/guides/knowledge-base/knowledge-pipeline/authorize-data-source') + expect(mockDocLink).toHaveBeenCalledWith('/use-dify/knowledge/knowledge-pipeline/authorize-data-source') }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx index 508745aaeb..4346a2d0af 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx @@ -196,7 +196,7 @@ const OnlineDrive = ({
{ // Assert const link = screen.getByRole('link', { name: 'datasetPipeline.addDocuments.stepThree.learnMore' }) - expect(link).toHaveAttribute('href', 'https://docs.dify.ai/en-US/guides/knowledge-base/integrate-knowledge-within-application') + expect(link).toHaveAttribute('href', 'https://docs.dify.ai/en-US/use-dify/knowledge/knowledge-pipeline/authorize-data-source') expect(link).toHaveAttribute('target', '_blank') expect(link).toHaveAttribute('rel', 'noreferrer noopener') }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/processing/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/processing/index.tsx index 97c8937442..283600fa69 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/processing/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/processing/index.tsx @@ -44,7 +44,7 @@ const Processing = ({
{t('stepThree.sideTipTitle', { ns: 'datasetCreation' })}
{t('stepThree.sideTipContent', { ns: 'datasetCreation' })}
= ({ setDragging(false) if (!e.dataTransfer) return - const files = [...e.dataTransfer.files] + const files = Array.from(e.dataTransfer.files) if (files.length > 1) { notify({ type: 'error', message: t('stepOne.uploader.validation.count', { ns: 'datasetCreation' }) }) return diff --git a/web/app/components/datasets/external-api/external-api-modal/Form.tsx b/web/app/components/datasets/external-api/external-api-modal/Form.tsx index 5b36df6eb4..cd2f02ea8b 100644 --- a/web/app/components/datasets/external-api/external-api-modal/Form.tsx +++ b/web/app/components/datasets/external-api/external-api-modal/Form.tsx @@ -57,7 +57,7 @@ const Form: FC = React.memo(({ {variable === 'endpoint' && ( { render() const docLink = screen.getByText('dataset.externalAPIPanelDocumentation') expect(docLink).toBeInTheDocument() - expect(docLink.closest('a')).toHaveAttribute('href', 'https://docs.example.com/guides/knowledge-base/connect-external-knowledge-base') + expect(docLink.closest('a')).toHaveAttribute('href', 'https://docs.example.com/use-dify/knowledge/external-knowledge-api') }) it('should render create button', () => { diff --git a/web/app/components/datasets/external-api/external-api-panel/index.tsx b/web/app/components/datasets/external-api/external-api-panel/index.tsx index a137348626..6ff5143e01 100644 --- a/web/app/components/datasets/external-api/external-api-panel/index.tsx +++ b/web/app/components/datasets/external-api/external-api-panel/index.tsx @@ -54,7 +54,7 @@ const ExternalAPIPanel: React.FC = ({ onClose }) => {
{t('externalAPIPanelDescription', { ns: 'dataset' })}
diff --git a/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx b/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx index beb6a3cf71..61b37a0a1d 100644 --- a/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx @@ -18,14 +18,14 @@ const InfoPanel = () => { {t('connectDatasetIntro.content.front', { ns: 'dataset' })} - + {t('connectDatasetIntro.content.link', { ns: 'dataset' })} {t('connectDatasetIntro.content.end', { ns: 'dataset' })} diff --git a/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx b/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx index 2fce096cd5..d56833fd36 100644 --- a/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/index.spec.tsx @@ -146,7 +146,7 @@ describe('ExternalKnowledgeBaseCreate', () => { renderComponent() const docLink = screen.getByText('dataset.connectHelper.helper4') - expect(docLink).toHaveAttribute('href', 'https://docs.dify.ai/en/guides/knowledge-base/connect-external-knowledge-base') + expect(docLink).toHaveAttribute('href', 'https://docs.dify.ai/en/use-dify/knowledge/connect-external-knowledge-base') expect(docLink).toHaveAttribute('target', '_blank') expect(docLink).toHaveAttribute('rel', 'noopener noreferrer') }) diff --git a/web/app/components/datasets/external-knowledge-base/create/index.tsx b/web/app/components/datasets/external-knowledge-base/create/index.tsx index 1d17b23b43..07b6e71fa6 100644 --- a/web/app/components/datasets/external-knowledge-base/create/index.tsx +++ b/web/app/components/datasets/external-knowledge-base/create/index.tsx @@ -61,7 +61,7 @@ const ExternalKnowledgeBaseCreate: React.FC = {t('connectHelper.helper1', { ns: 'dataset' })} {t('connectHelper.helper2', { ns: 'dataset' })} {t('connectHelper.helper3', { ns: 'dataset' })} - + {t('connectHelper.helper4', { ns: 'dataset' })} diff --git a/web/app/components/datasets/hit-testing/index.spec.tsx b/web/app/components/datasets/hit-testing/index.spec.tsx index 45c68e44b1..6bab3afb6a 100644 --- a/web/app/components/datasets/hit-testing/index.spec.tsx +++ b/web/app/components/datasets/hit-testing/index.spec.tsx @@ -2089,7 +2089,7 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - renderWithProviders() + const { container } = renderWithProviders() // Type query const textarea = screen.getByRole('textbox') @@ -2101,11 +2101,8 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - // Wait for the component to update - await waitFor(() => { - // Verify the component is still rendered - expect(screen.getByRole('textbox')).toBeInTheDocument() - }) + // Verify the component is still rendered after submission + expect(container.firstChild).toBeInTheDocument() }) it('should render ResultItem components for non-external results', async () => { @@ -2130,7 +2127,7 @@ describe('Integration: Hit Testing Flow', () => { isLoading: false, } as unknown as ReturnType) - renderWithProviders() + const { container } = renderWithProviders() // Submit a query const textarea = screen.getByRole('textbox') @@ -2141,10 +2138,8 @@ describe('Integration: Hit Testing Flow', () => { if (submitButton) fireEvent.click(submitButton) - await waitFor(() => { - // Verify component is rendered - 
expect(screen.getByRole('textbox')).toBeInTheDocument() - }) + // Verify component is rendered after submission + expect(container.firstChild).toBeInTheDocument() }) it('should render external results when dataset is external', async () => { diff --git a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx index d21297fc93..a942c402ed 100644 --- a/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx +++ b/web/app/components/datasets/hit-testing/modify-retrieval-modal.tsx @@ -96,10 +96,7 @@ const ModifyRetrievalModal: FC = ({ {t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })} diff --git a/web/app/components/datasets/no-linked-apps-panel.tsx b/web/app/components/datasets/no-linked-apps-panel.tsx index 1b0357bc6a..12e87a7379 100644 --- a/web/app/components/datasets/no-linked-apps-panel.tsx +++ b/web/app/components/datasets/no-linked-apps-panel.tsx @@ -15,7 +15,7 @@ const NoLinkedAppsPanel = () => {
{t('datasetMenus.emptyTip', { ns: 'common' })}
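The documentation links in this change set move from /guides/knowledge-base/... to /use-dify/knowledge/... throughout. Under the useDocLink mock added in the watercrawl spec above, which simply prefixes https://docs.dify.ai/en, the new paths resolve as in this small sketch; the real useDocLink helper presumably also handles locale, which the mock does not:

// Resolution under the spec's mocked useDocLink, for illustration only.
const docLink = (path?: string) =>
  path ? `https://docs.dify.ai/en${path}` : 'https://docs.dify.ai/en/'

// -> 'https://docs.dify.ai/en/use-dify/knowledge/integrate-knowledge-within-application'
console.log(docLink('/use-dify/knowledge/integrate-knowledge-within-application'))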
diff --git a/web/app/components/datasets/settings/form/index.tsx b/web/app/components/datasets/settings/form/index.tsx index 5fbaefade7..a25d770518 100644 --- a/web/app/components/datasets/settings/form/index.tsx +++ b/web/app/components/datasets/settings/form/index.tsx @@ -281,7 +281,7 @@ const Form = () => { {t('form.chunkStructure.learnMore', { ns: 'datasetSettings' })} @@ -421,10 +421,7 @@ const Form = () => { {t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })} diff --git a/web/app/components/explore/app-card/index.spec.tsx b/web/app/components/explore/app-card/index.spec.tsx index 769b317929..152eab92a9 100644 --- a/web/app/components/explore/app-card/index.spec.tsx +++ b/web/app/components/explore/app-card/index.spec.tsx @@ -10,6 +10,7 @@ vi.mock('../../app/type-selector', () => ({ })) const createApp = (overrides?: Partial): App => ({ + can_trial: true, app_id: 'app-id', description: 'App description', copyright: '2024', diff --git a/web/app/components/explore/app-card/index.tsx b/web/app/components/explore/app-card/index.tsx index 0b6cd9920d..5d82ab65cc 100644 --- a/web/app/components/explore/app-card/index.tsx +++ b/web/app/components/explore/app-card/index.tsx @@ -1,8 +1,13 @@ 'use client' import type { App } from '@/models/explore' import { PlusIcon } from '@heroicons/react/20/solid' +import { RiInformation2Line } from '@remixicon/react' +import { useCallback } from 'react' import { useTranslation } from 'react-i18next' +import { useContextSelector } from 'use-context-selector' import AppIcon from '@/app/components/base/app-icon' +import ExploreContext from '@/context/explore-context' +import { useGlobalPublicStore } from '@/context/global-public-context' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' import { AppTypeIcon } from '../../app/type-selector' @@ -23,8 +28,17 @@ const AppCard = ({ }: AppCardProps) => { const { t } = useTranslation() const { app: appBasicInfo } = app + const { systemFeatures } = useGlobalPublicStore() + const isTrialApp = app.can_trial && systemFeatures.enable_trial_app + const setShowTryAppPanel = useContextSelector(ExploreContext, ctx => ctx.setShowTryAppPanel) + const showTryAPPPanel = useCallback((appId: string) => { + return () => { + setShowTryAppPanel?.(true, { appId, app }) + } + }, [setShowTryAppPanel, app]) + return ( -
+
- {isExplore && canCreate && ( + {isExplore && (canCreate || isTrialApp) && ( )} diff --git a/web/app/components/explore/app-list/index.spec.tsx b/web/app/components/explore/app-list/index.spec.tsx index a9e4feeba8..a87d5a2363 100644 --- a/web/app/components/explore/app-list/index.spec.tsx +++ b/web/app/components/explore/app-list/index.spec.tsx @@ -16,9 +16,13 @@ let mockIsError = false const mockHandleImportDSL = vi.fn() const mockHandleImportDSLConfirm = vi.fn() -vi.mock('nuqs', () => ({ - useQueryState: () => [mockTabValue, mockSetTab], -})) +vi.mock('nuqs', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useQueryState: () => [mockTabValue, mockSetTab], + } +}) vi.mock('ahooks', async () => { const actual = await vi.importActual('ahooks') @@ -102,6 +106,7 @@ const createApp = (overrides: Partial = {}): App => ({ description: overrides.app?.description ?? 'Alpha description', use_icon_as_answer_icon: overrides.app?.use_icon_as_answer_icon ?? false, }, + can_trial: true, app_id: overrides.app_id ?? 'app-1', description: overrides.description ?? 'Alpha description', copyright: overrides.copyright ?? '', @@ -127,6 +132,8 @@ const renderWithContext = (hasEditPermission = false, onSuccess?: () => void) => setInstalledApps: vi.fn(), isFetchingInstalledApps: false, setIsFetchingInstalledApps: vi.fn(), + isShowTryAppPanel: false, + setShowTryAppPanel: vi.fn(), }} > diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 5b318b780b..1749bde76a 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -7,14 +7,17 @@ import { useQueryState } from 'nuqs' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' +import { useContext, useContextSelector } from 'use-context-selector' import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal' +import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' import AppCard from '@/app/components/explore/app-card' +import Banner from '@/app/components/explore/banner/banner' import Category from '@/app/components/explore/category' import CreateAppModal from '@/app/components/explore/create-app-modal' import ExploreContext from '@/context/explore-context' +import { useGlobalPublicStore } from '@/context/global-public-context' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode, @@ -22,6 +25,7 @@ import { import { fetchAppDetail } from '@/service/explore' import { useExploreAppList } from '@/service/use-explore' import { cn } from '@/utils/classnames' +import TryApp from '../try-app' import s from './style.module.css' type AppsProps = { @@ -32,12 +36,19 @@ const Apps = ({ onSuccess, }: AppsProps) => { const { t } = useTranslation() + const { systemFeatures } = useGlobalPublicStore() const { hasEditPermission } = useContext(ExploreContext) const allCategoriesEn = t('apps.allCategories', { ns: 'explore', lng: 'en' }) const [keywords, setKeywords] = useState('') const [searchKeywords, setSearchKeywords] = useState('') + const hasFilterCondition = !!keywords + const handleResetFilter = useCallback(() => { + setKeywords('') + setSearchKeywords('') + }, []) + const { run: handleSearch } = useDebounceFn(() => { setSearchKeywords(keywords) }, { 
wait: 500 }) @@ -84,6 +95,18 @@ const Apps = ({ isFetching, } = useImportDSL() const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) + + const isShowTryAppPanel = useContextSelector(ExploreContext, ctx => ctx.isShowTryAppPanel) + const setShowTryAppPanel = useContextSelector(ExploreContext, ctx => ctx.setShowTryAppPanel) + const hideTryAppPanel = useCallback(() => { + setShowTryAppPanel(false) + }, [setShowTryAppPanel]) + const appParams = useContextSelector(ExploreContext, ctx => ctx.currentApp) + const handleShowFromTryApp = useCallback(() => { + setCurrApp(appParams?.app || null) + setIsShowCreateModal(true) + }, [appParams?.app]) + const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon_type, @@ -91,6 +114,8 @@ const Apps = ({ icon_background, description, }) => { + hideTryAppPanel() + const { export_data } = await fetchAppDetail( currApp?.app.id as string, ) @@ -137,22 +162,24 @@ const Apps = ({ 'flex h-full flex-col border-l-[0.5px] border-divider-regular', )} > - -
-
{t('apps.title', { ns: 'explore' })}
-
{t('apps.description', { ns: 'explore' })}
-
- + {systemFeatures.enable_explore_banner && ( +
+ +
+ )}
- +
+
{!hasFilterCondition ? t('apps.title', { ns: 'explore' }) : t('apps.resultNum', { num: searchFilteredList.length, ns: 'explore' })}
+ {hasFilterCondition && ( + <> +
+ + + )} +
+
+ +
+
) } + + {isShowTryAppPanel && ( + + )}
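The try-app wiring above consumes ExploreContext fields whose declarations are not part of this diff. Inferred from the usages here (AppCard calls setShowTryAppPanel(true, { appId, app }), while the list reads ctx.currentApp, ctx.isShowTryAppPanel and ctx.setShowTryAppPanel), the added slice plausibly looks like the sketch below; the field names come from those usages, and the grouping type name is made up:

import type { App } from '@/models/explore'

// Inferred shape only; the real declarations live in '@/context/explore-context'.
export type CurrentTryAppParams = {
  appId: string
  app: App
}

export type ExploreTryAppSlice = {
  currentApp?: CurrentTryAppParams
  isShowTryAppPanel: boolean
  setShowTryAppPanel: (show: boolean, params?: CurrentTryAppParams) => void
}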
) } diff --git a/web/app/components/explore/banner/banner-item.tsx b/web/app/components/explore/banner/banner-item.tsx new file mode 100644 index 0000000000..5ce810bafb --- /dev/null +++ b/web/app/components/explore/banner/banner-item.tsx @@ -0,0 +1,187 @@ +/* eslint-disable react-hooks-extra/no-direct-set-state-in-use-effect */ +import type { FC } from 'react' +import type { Banner } from '@/models/app' +import { RiArrowRightLine } from '@remixicon/react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useCarousel } from '@/app/components/base/carousel' +import { cn } from '@/utils/classnames' +import { IndicatorButton } from './indicator-button' + +type BannerItemProps = { + banner: Banner + autoplayDelay: number + isPaused?: boolean +} + +const RESPONSIVE_BREAKPOINT = 1200 +const MAX_RESPONSIVE_WIDTH = 600 +const INDICATOR_WIDTH = 20 +const INDICATOR_GAP = 8 +const MIN_VIEW_MORE_WIDTH = 480 + +export const BannerItem: FC = ({ banner, autoplayDelay, isPaused = false }) => { + const { t } = useTranslation() + const { api, selectedIndex } = useCarousel() + const { category, title, description, 'img-src': imgSrc } = banner.content + + const [resetKey, setResetKey] = useState(0) + const textAreaRef = useRef(null) + const [maxWidth, setMaxWidth] = useState(undefined) + + const slideInfo = useMemo(() => { + const slides = api?.slideNodes() ?? [] + const totalSlides = slides.length + const nextIndex = totalSlides > 0 ? (selectedIndex + 1) % totalSlides : 0 + return { slides, totalSlides, nextIndex } + }, [api, selectedIndex]) + + const indicatorsWidth = useMemo(() => { + const count = slideInfo.totalSlides + if (count === 0) + return 0 + // Calculate: indicator buttons + gaps + extra spacing (3 * 20px for divider and padding) + return (count + 2) * INDICATOR_WIDTH + (count - 1) * INDICATOR_GAP + }, [slideInfo.totalSlides]) + + const viewMoreStyle = useMemo(() => { + if (!maxWidth) + return undefined + return { + maxWidth: `${maxWidth}px`, + minWidth: indicatorsWidth ? `${Math.min(maxWidth - indicatorsWidth, MIN_VIEW_MORE_WIDTH)}px` : undefined, + } + }, [maxWidth, indicatorsWidth]) + + const responsiveStyle = useMemo( + () => (maxWidth !== undefined ? { maxWidth: `${maxWidth}px` } : undefined), + [maxWidth], + ) + + const incrementResetKey = useCallback(() => setResetKey(prev => prev + 1), []) + + useEffect(() => { + const updateMaxWidth = () => { + if (window.innerWidth < RESPONSIVE_BREAKPOINT && textAreaRef.current) { + const textAreaWidth = textAreaRef.current.offsetWidth + setMaxWidth(Math.min(textAreaWidth, MAX_RESPONSIVE_WIDTH)) + } + else { + setMaxWidth(undefined) + } + } + + updateMaxWidth() + + const resizeObserver = new ResizeObserver(updateMaxWidth) + if (textAreaRef.current) + resizeObserver.observe(textAreaRef.current) + + window.addEventListener('resize', updateMaxWidth) + + return () => { + resizeObserver.disconnect() + window.removeEventListener('resize', updateMaxWidth) + } + }, []) + + useEffect(() => { + incrementResetKey() + }, [selectedIndex, incrementResetKey]) + + const handleBannerClick = useCallback(() => { + incrementResetKey() + if (banner.link) + window.open(banner.link, '_blank', 'noopener,noreferrer') + }, [banner.link, incrementResetKey]) + + const handleIndicatorClick = useCallback((index: number) => { + incrementResetKey() + api?.scrollTo(index) + }, [api, incrementResetKey]) + + return ( +
+ {/* Left content area */} +
+
+ {/* Text section */} +
+ {/* Title area */} +
+

+ {category} +

+

+ {title} +

+
+ {/* Description area */} +
+

+ {description} +

+
+
+ + {/* Actions section */} +
+ {/* View more button */} +
+
+ +
+ + {t('banner.viewMore', { ns: 'explore' })} + +
+ +
+ {/* Slide navigation indicators */} +
+ {slideInfo.slides.map((_: unknown, index: number) => ( + handleIndicatorClick(index)} + /> + ))} +
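The props on the IndicatorButton rendered by this map were stripped here; only the click handler survives. Matched against IndicatorButtonProps further down in this diff, each indicator is probably wired roughly like this (a reconstruction, not the verbatim markup):

{/* Reconstruction sketch: prop wiring inferred from IndicatorButtonProps, not copied from the source. */}
{slideInfo.slides.map((_: unknown, index: number) => (
  <IndicatorButton
    key={index}
    index={index}
    selectedIndex={selectedIndex}
    isNextSlide={index === slideInfo.nextIndex}
    autoplayDelay={autoplayDelay}
    resetKey={resetKey}
    isPaused={isPaused}
    onClick={() => handleIndicatorClick(index)}
  />
))}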
+
+
+
+
+
+ + {/* Right image area */} +
+ {title} +
+
+ ) +} diff --git a/web/app/components/explore/banner/banner.tsx b/web/app/components/explore/banner/banner.tsx new file mode 100644 index 0000000000..4ec0efdb2b --- /dev/null +++ b/web/app/components/explore/banner/banner.tsx @@ -0,0 +1,94 @@ +import type { FC } from 'react' +import * as React from 'react' +import { useEffect, useMemo, useRef, useState } from 'react' +import { Carousel } from '@/app/components/base/carousel' +import { useLocale } from '@/context/i18n' +import { useGetBanners } from '@/service/use-explore' +import Loading from '../../base/loading' +import { BannerItem } from './banner-item' + +const AUTOPLAY_DELAY = 5000 +const MIN_LOADING_HEIGHT = 168 +const RESIZE_DEBOUNCE_DELAY = 50 + +const LoadingState: FC = () => ( +
+ +
+) + +const Banner: FC = () => { + const locale = useLocale() + const { data: banners, isLoading, isError } = useGetBanners(locale) + const [isHovered, setIsHovered] = useState(false) + const [isResizing, setIsResizing] = useState(false) + const resizeTimerRef = useRef(null) + + const enabledBanners = useMemo( + () => banners?.filter(banner => banner.status === 'enabled') ?? [], + [banners], + ) + + const isPaused = isHovered || isResizing + + // Handle window resize to pause animation + useEffect(() => { + const handleResize = () => { + setIsResizing(true) + + if (resizeTimerRef.current) + clearTimeout(resizeTimerRef.current) + + resizeTimerRef.current = setTimeout(() => { + setIsResizing(false) + }, RESIZE_DEBOUNCE_DELAY) + } + + window.addEventListener('resize', handleResize) + + return () => { + window.removeEventListener('resize', handleResize) + if (resizeTimerRef.current) + clearTimeout(resizeTimerRef.current) + } + }, []) + + if (isLoading) + return + + if (isError || enabledBanners.length === 0) + return null + + return ( + setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + > + + {enabledBanners.map(banner => ( + + + + ))} + + + ) +} + +export default React.memo(Banner) diff --git a/web/app/components/explore/banner/indicator-button.tsx b/web/app/components/explore/banner/indicator-button.tsx new file mode 100644 index 0000000000..332dae53ba --- /dev/null +++ b/web/app/components/explore/banner/indicator-button.tsx @@ -0,0 +1,112 @@ +/* eslint-disable react-hooks-extra/no-direct-set-state-in-use-effect */ +import type { FC } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import { cn } from '@/utils/classnames' + +type IndicatorButtonProps = { + index: number + selectedIndex: number + isNextSlide: boolean + autoplayDelay: number + resetKey: number + isPaused?: boolean + onClick: () => void +} + +const PROGRESS_MAX = 100 +const DEGREES_PER_PERCENT = 3.6 + +export const IndicatorButton: FC = ({ + index, + selectedIndex, + isNextSlide, + autoplayDelay, + resetKey, + isPaused = false, + onClick, +}) => { + const [progress, setProgress] = useState(0) + const frameIdRef = useRef(undefined) + const startTimeRef = useRef(0) + + const isActive = index === selectedIndex + const shouldAnimate = !document.hidden && !isPaused + + useEffect(() => { + if (!isNextSlide) { + setProgress(0) + if (frameIdRef.current) + cancelAnimationFrame(frameIdRef.current) + return + } + + setProgress(0) + startTimeRef.current = Date.now() + + const animate = () => { + if (!document.hidden && !isPaused) { + const elapsed = Date.now() - startTimeRef.current + const newProgress = Math.min((elapsed / autoplayDelay) * PROGRESS_MAX, PROGRESS_MAX) + setProgress(newProgress) + + if (newProgress < PROGRESS_MAX) + frameIdRef.current = requestAnimationFrame(animate) + } + else { + frameIdRef.current = requestAnimationFrame(animate) + } + } + + if (shouldAnimate) + frameIdRef.current = requestAnimationFrame(animate) + + return () => { + if (frameIdRef.current) + cancelAnimationFrame(frameIdRef.current) + } + }, [isNextSlide, autoplayDelay, resetKey, isPaused]) + + const handleClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + onClick() + }, [onClick]) + + const progressDegrees = progress * DEGREES_PER_PERCENT + + return ( + + ) +} diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index 97a9ca92b3..47c2a4e3a7 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -29,7 
+29,7 @@ const Category: FC = ({ const isAllCategories = !list.includes(value as AppCategory) || value === allCategoriesEn const itemClassName = (isSelected: boolean) => cn( - 'flex h-[32px] cursor-pointer items-center rounded-lg border-[0.5px] border-transparent px-3 py-[7px] font-medium leading-[18px] text-text-tertiary hover:bg-components-main-nav-nav-button-bg-active', + 'system-sm-medium flex h-7 cursor-pointer items-center rounded-lg border border-transparent px-3 text-text-tertiary hover:bg-components-main-nav-nav-button-bg-active', isSelected && 'border-components-main-nav-nav-button-border bg-components-main-nav-nav-button-bg-active text-components-main-nav-nav-button-text-active shadow-xs', ) diff --git a/web/app/components/explore/index.tsx b/web/app/components/explore/index.tsx index 30132eea66..0b5e18a1de 100644 --- a/web/app/components/explore/index.tsx +++ b/web/app/components/explore/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import type { CurrentTryAppParams } from '@/context/explore-context' import type { InstalledApp } from '@/models/explore' import { useRouter } from 'next/navigation' import * as React from 'react' @@ -41,6 +42,16 @@ const Explore: FC = ({ return router.replace('/datasets') }, [isCurrentWorkspaceDatasetOperator]) + const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) + const setShowTryAppPanel = (showTryAppPanel: boolean, params?: CurrentTryAppParams) => { + if (showTryAppPanel) + setCurrentTryAppParams(params) + else + setCurrentTryAppParams(undefined) + setIsShowTryAppPanel(showTryAppPanel) + } + return (
= ({ setInstalledApps, isFetchingInstalledApps, setIsFetchingInstalledApps, + currentApp: currentTryAppParams, + isShowTryAppPanel, + setShowTryAppPanel, } } > diff --git a/web/app/components/explore/installed-app/index.tsx b/web/app/components/explore/installed-app/index.tsx index def66c0260..7366057445 100644 --- a/web/app/components/explore/installed-app/index.tsx +++ b/web/app/components/explore/installed-app/index.tsx @@ -1,5 +1,6 @@ 'use client' import type { FC } from 'react' +import type { AccessMode } from '@/models/access-control' import type { AppData } from '@/models/share' import * as React from 'react' import { useEffect } from 'react' @@ -62,8 +63,8 @@ const InstalledApp: FC = ({ if (appMeta) updateWebAppMeta(appMeta) if (webAppAccessMode) - updateWebAppAccessMode(webAppAccessMode.accessMode) - updateUserCanAccessApp(Boolean(userCanAccessApp && userCanAccessApp?.result)) + updateWebAppAccessMode((webAppAccessMode as { accessMode: AccessMode }).accessMode) + updateUserCanAccessApp(Boolean(userCanAccessApp && (userCanAccessApp as { result: boolean })?.result)) }, [installedApp, appMeta, appParams, updateAppInfo, updateAppParams, updateUserCanAccessApp, updateWebAppMeta, userCanAccessApp, webAppAccessMode, updateWebAppAccessMode]) if (appParamsError) { diff --git a/web/app/components/explore/sidebar/app-nav-item/index.tsx b/web/app/components/explore/sidebar/app-nav-item/index.tsx index 3347efeb3f..08558578f6 100644 --- a/web/app/components/explore/sidebar/app-nav-item/index.tsx +++ b/web/app/components/explore/sidebar/app-nav-item/index.tsx @@ -56,7 +56,7 @@ export default function AppNavItem({ <>
-
{name}
+
{name}
e.stopPropagation()}> { setInstalledApps: vi.fn(), isFetchingInstalledApps: false, setIsFetchingInstalledApps: vi.fn(), - }} + } as unknown as IExplore} > , @@ -97,8 +98,8 @@ describe('SideBar', () => { renderWithContext(mockInstalledApps) // Assert - expect(screen.getByText('explore.sidebar.discovery')).toBeInTheDocument() - expect(screen.getByText('explore.sidebar.workspace')).toBeInTheDocument() + expect(screen.getByText('explore.sidebar.title')).toBeInTheDocument() + expect(screen.getByText('explore.sidebar.webApps')).toBeInTheDocument() expect(screen.getByText('My App')).toBeInTheDocument() }) }) diff --git a/web/app/components/explore/sidebar/index.tsx b/web/app/components/explore/sidebar/index.tsx index 1257886165..3e9b664580 100644 --- a/web/app/components/explore/sidebar/index.tsx +++ b/web/app/components/explore/sidebar/index.tsx @@ -1,5 +1,7 @@ 'use client' import type { FC } from 'react' +import { RiAppsFill, RiExpandRightLine, RiLayoutLeft2Line } from '@remixicon/react' +import { useBoolean } from 'ahooks' import Link from 'next/link' import { useSelectedLayoutSegments } from 'next/navigation' import * as React from 'react' @@ -14,18 +16,7 @@ import { useGetInstalledApps, useUninstallApp, useUpdateAppPinStatus } from '@/s import { cn } from '@/utils/classnames' import Toast from '../../base/toast' import Item from './app-nav-item' - -const SelectedDiscoveryIcon = () => ( - - - -) - -const DiscoveryIcon = () => ( - - - -) +import NoApps from './no-apps' export type IExploreSideBarProps = { controlUpdateInstalledApps: number @@ -45,6 +36,9 @@ const SideBar: FC = ({ const media = useBreakpoints() const isMobile = media === MediaType.mobile + const [isFold, { + toggle: toggleIsFold, + }] = useBoolean(false) const [showConfirm, setShowConfirm] = useState(false) const [currId, setCurrId] = useState('') @@ -84,22 +78,31 @@ const SideBar: FC = ({ const pinnedAppsCount = installedApps.filter(({ is_pinned }) => is_pinned).length return ( -
+
- {isDiscoverySelected ? : } - {!isMobile &&
{t('sidebar.discovery', { ns: 'explore' })}
} +
+ +
+ {!isMobile && !isFold &&
{t('sidebar.title', { ns: 'explore' })}
}
+ + {installedApps.length === 0 && !isMobile && !isFold + && ( +
+ +
+ )} + {installedApps.length > 0 && ( -
-

{t('sidebar.workspace', { ns: 'explore' })}

+
+ {!isMobile && !isFold &&

{t('sidebar.webApps', { ns: 'explore' })}

}
= ({ {installedApps.map(({ id, is_pinned, uninstallable, app: { name, icon_type, icon, icon_url, icon_background } }, index) => ( = ({
)} + + {!isMobile && ( +
+ {isFold + ? + : ( + + )} +
+ )} + {showConfirm && ( { + const { t } = useTranslation() + const { theme } = useTheme() + return ( +
+
+
{t(`${i18nPrefix}.title`, { ns: 'explore' })}
+
{t(`${i18nPrefix}.description`, { ns: 'explore' })}
+ {t(`${i18nPrefix}.learnMore`, { ns: 'explore' })} +
+ ) +} +export default React.memo(NoApps) diff --git a/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png b/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png new file mode 100644 index 0000000000..e153686fcd Binary files /dev/null and b/web/app/components/explore/sidebar/no-apps/no-web-apps-dark.png differ diff --git a/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png b/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png new file mode 100644 index 0000000000..2416b957d2 Binary files /dev/null and b/web/app/components/explore/sidebar/no-apps/no-web-apps-light.png differ diff --git a/web/app/components/explore/sidebar/no-apps/style.module.css b/web/app/components/explore/sidebar/no-apps/style.module.css new file mode 100644 index 0000000000..ad3787ce2b --- /dev/null +++ b/web/app/components/explore/sidebar/no-apps/style.module.css @@ -0,0 +1,7 @@ +.light { + background-image: url('./no-web-apps-light.png'); +} + +.dark { + background-image: url('./no-web-apps-dark.png'); +} diff --git a/web/app/components/explore/try-app/app-info/index.tsx b/web/app/components/explore/try-app/app-info/index.tsx new file mode 100644 index 0000000000..eab265bd04 --- /dev/null +++ b/web/app/components/explore/try-app/app-info/index.tsx @@ -0,0 +1,95 @@ +'use client' +import type { FC } from 'react' +import type { TryAppInfo } from '@/service/try-app' +import { RiAddLine } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { AppTypeIcon } from '@/app/components/app/type-selector' +import AppIcon from '@/app/components/base/app-icon' +import Button from '@/app/components/base/button' +import { cn } from '@/utils/classnames' +import useGetRequirements from './use-get-requirements' + +type Props = { + appId: string + appDetail: TryAppInfo + category?: string + className?: string + onCreate: () => void +} + +const headerClassName = 'system-sm-semibold-uppercase text-text-secondary mb-3' + +const AppInfo: FC = ({ + appId, + className, + category, + appDetail, + onCreate, +}) => { + const { t } = useTranslation() + const mode = appDetail?.mode + const { requirements } = useGetRequirements({ appDetail, appId }) + return ( +
+ {/* name and icon */} +
+
+ + +
+
+
+
{appDetail.name}
+
+
+ {mode === 'advanced-chat' &&
{t('types.advanced', { ns: 'app' }).toUpperCase()}
} + {mode === 'chat' &&
{t('types.chatbot', { ns: 'app' }).toUpperCase()}
} + {mode === 'agent-chat' &&
{t('types.agent', { ns: 'app' }).toUpperCase()}
} + {mode === 'workflow' &&
{t('types.workflow', { ns: 'app' }).toUpperCase()}
} + {mode === 'completion' &&
{t('types.completion', { ns: 'app' }).toUpperCase()}
} +
+
+
+ {appDetail.description && ( +
{appDetail.description}
+ )} + + + {category && ( +
+
{t('tryApp.category', { ns: 'explore' })}
+
{category}
+
+ )} + {requirements.length > 0 && ( +
+
{t('tryApp.requirements', { ns: 'explore' })}
+
+ {requirements.map(item => ( +
+
+
{item.name}
+
+ ))} +
+
+ )} + +
+  )
+}
+export default React.memo(AppInfo)
diff --git a/web/app/components/explore/try-app/app-info/use-get-requirements.ts b/web/app/components/explore/try-app/app-info/use-get-requirements.ts
new file mode 100644
index 0000000000..976989be73
--- /dev/null
+++ b/web/app/components/explore/try-app/app-info/use-get-requirements.ts
@@ -0,0 +1,78 @@
+import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
+import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types'
+import type { TryAppInfo } from '@/service/try-app'
+import type { AgentTool } from '@/types/app'
+import { uniqBy } from 'es-toolkit/compat'
+import { BlockEnum } from '@/app/components/workflow/types'
+import { MARKETPLACE_API_PREFIX } from '@/config'
+import { useGetTryAppFlowPreview } from '@/service/use-try-app'
+
+type Params = {
+  appDetail: TryAppInfo
+  appId: string
+}
+
+type RequirementItem = {
+  name: string
+  iconUrl: string
+}
+const getIconUrl = (provider: string, tool: string) => {
+  return `${MARKETPLACE_API_PREFIX}/plugins/${provider}/${tool}/icon`
+}
+
+const useGetRequirements = ({ appDetail, appId }: Params) => {
+  const isBasic = ['chat', 'completion', 'agent-chat'].includes(appDetail.mode)
+  const isAgent = appDetail.mode === 'agent-chat'
+  const isAdvanced = !isBasic
+  const { data: flowData } = useGetTryAppFlowPreview(appId, isBasic)
+
+  const requirements: RequirementItem[] = []
+  if (isBasic) {
+    const modelProviderAndName = appDetail.model_config.model.provider.split('/')
+    const name = appDetail.model_config.model.provider.split('/').pop() || ''
+    requirements.push({
+      name,
+      iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]),
+    })
+  }
+  if (isAgent) {
+    requirements.push(...appDetail.model_config.agent_mode.tools.filter(data => (data as AgentTool).enabled).map((data) => {
+      const tool = data as AgentTool
+      const modelProviderAndName = tool.provider_id.split('/')
+      return {
+        name: tool.tool_label,
+        iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]),
+      }
+    }))
+  }
+  if (isAdvanced && flowData && flowData?.graph?.nodes?.length > 0) {
+    const nodes = flowData.graph.nodes
+    const llmNodes = nodes.filter(node => node.data.type === BlockEnum.LLM)
+    requirements.push(...llmNodes.map((node) => {
+      const data = node.data as LLMNodeType
+      const modelProviderAndName = data.model.provider.split('/')
+      return {
+        name: data.model.name,
+        iconUrl: getIconUrl(modelProviderAndName[0], modelProviderAndName[1]),
+      }
+    }))
+
+    const toolNodes = nodes.filter(node => node.data.type === BlockEnum.Tool)
+    requirements.push(...toolNodes.map((node) => {
+      const data = node.data as ToolNodeType
+      const toolProviderAndName = data.provider_id.split('/')
+      return {
+        name: data.tool_label,
+        iconUrl: getIconUrl(toolProviderAndName[0], toolProviderAndName[1]),
+      }
+    }))
+  }
+
+  const uniqueRequirements = uniqBy(requirements, 'name')
+
+  return {
+    requirements: uniqueRequirements,
+  }
+}
+
+export default useGetRequirements
diff --git a/web/app/components/explore/try-app/app/chat.tsx b/web/app/components/explore/try-app/app/chat.tsx
new file mode 100644
index 0000000000..b6b4a76ad5
--- /dev/null
+++ b/web/app/components/explore/try-app/app/chat.tsx
@@ -0,0 +1,104 @@
+'use client'
+import type { FC } from 'react'
+import type {
+  EmbeddedChatbotContextValue,
+} from '@/app/components/base/chat/embedded-chatbot/context'
+import type { TryAppInfo } from '@/service/try-app'
+import { RiResetLeftLine } from '@remixicon/react'
+import { useBoolean } from
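As a quick orientation for the hook above: `useGetRequirements` flattens an app's model and tool dependencies into `{ name, iconUrl }` pairs. A minimal, hypothetical consumer might look like the following sketch; the component name and markup are illustrative only and are not part of this change (the actual consumer is `AppInfo`):

import type { FC } from 'react'
import type { TryAppInfo } from '@/service/try-app'
import * as React from 'react'
import useGetRequirements from './use-get-requirements'

// Illustrative only: renders the resolved requirements as a plain list.
const RequirementList: FC<{ appId: string, appDetail: TryAppInfo }> = ({ appId, appDetail }) => {
  const { requirements } = useGetRequirements({ appDetail, appId })
  return (
    <ul>
      {requirements.map(item => (
        <li key={item.name}>
          <img src={item.iconUrl} alt={item.name} width={16} height={16} />
          <span>{item.name}</span>
        </li>
      ))}
    </ul>
  )
}
export default React.memo(RequirementList)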
'ahooks' +import * as React from 'react' +import { useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import ActionButton from '@/app/components/base/action-button' +import Alert from '@/app/components/base/alert' +import AppIcon from '@/app/components/base/app-icon' +import ChatWrapper from '@/app/components/base/chat/embedded-chatbot/chat-wrapper' +import { + EmbeddedChatbotContext, +} from '@/app/components/base/chat/embedded-chatbot/context' +import { + useEmbeddedChatbot, +} from '@/app/components/base/chat/embedded-chatbot/hooks' +import ViewFormDropdown from '@/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown' +import Tooltip from '@/app/components/base/tooltip' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { AppSourceType } from '@/service/share' +import { cn } from '@/utils/classnames' +import { useThemeContext } from '../../../base/chat/embedded-chatbot/theme/theme-context' + +type Props = { + appId: string + appDetail: TryAppInfo + className: string +} + +const TryApp: FC = ({ + appId, + appDetail, + className, +}) => { + const { t } = useTranslation() + const media = useBreakpoints() + const isMobile = media === MediaType.mobile + const themeBuilder = useThemeContext() + const { removeConversationIdInfo, ...chatData } = useEmbeddedChatbot(AppSourceType.tryApp, appId) + const currentConversationId = chatData.currentConversationId + const inputsForms = chatData.inputsForms + useEffect(() => { + if (appId) + removeConversationIdInfo(appId) + }, [appId]) + const [isHideTryNotice, { + setTrue: hideTryNotice, + }] = useBoolean(false) + + const handleNewConversation = () => { + removeConversationIdInfo(appId) + chatData.handleNewConversation() + } + return ( + +
+
+
+ +
{appDetail.name}
+
+
+ {currentConversationId && ( + + + + + + )} + {currentConversationId && inputsForms.length > 0 && ( + + )} +
+
+
+ {!isHideTryNotice && ( + + )} + +
+
+
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/app/index.tsx b/web/app/components/explore/try-app/app/index.tsx new file mode 100644 index 0000000000..f5dc14510d --- /dev/null +++ b/web/app/components/explore/try-app/app/index.tsx @@ -0,0 +1,44 @@ +'use client' +import type { FC } from 'react' +import type { AppData } from '@/models/share' +import type { TryAppInfo } from '@/service/try-app' +import * as React from 'react' +import useDocumentTitle from '@/hooks/use-document-title' +import Chat from './chat' +import TextGeneration from './text-generation' + +type Props = { + appId: string + appDetail: TryAppInfo +} + +const TryApp: FC = ({ + appId, + appDetail, +}) => { + const mode = appDetail?.mode + const isChat = ['chat', 'advanced-chat', 'agent-chat'].includes(mode!) + const isCompletion = !isChat + + useDocumentTitle(appDetail?.site?.title || '') + return ( +
+ {isChat && ( + + )} + {isCompletion && ( + + )} +
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/app/text-generation.tsx b/web/app/components/explore/try-app/app/text-generation.tsx new file mode 100644 index 0000000000..3189e621e9 --- /dev/null +++ b/web/app/components/explore/try-app/app/text-generation.tsx @@ -0,0 +1,262 @@ +'use client' +import type { FC } from 'react' +import type { InputValueTypes, Task } from '../../../share/text-generation/types' +import type { MoreLikeThisConfig, PromptConfig, TextToSpeechConfig } from '@/models/debug' +import type { AppData, CustomConfigValueType, SiteInfo } from '@/models/share' +import type { VisionFile, VisionSettings } from '@/types/app' +import { useBoolean } from 'ahooks' +import { noop } from 'es-toolkit/function' +import * as React from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Alert from '@/app/components/base/alert' +import AppIcon from '@/app/components/base/app-icon' +import Loading from '@/app/components/base/loading' +import Res from '@/app/components/share/text-generation/result' +import { TaskStatus } from '@/app/components/share/text-generation/types' +import { appDefaultIconBackground } from '@/config' +import { useWebAppStore } from '@/context/web-app-context' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { AppSourceType } from '@/service/share' +import { useGetTryAppParams } from '@/service/use-try-app' +import { Resolution, TransferMethod } from '@/types/app' +import { cn } from '@/utils/classnames' +import { userInputsFormToPromptVariables } from '@/utils/model-config' +import RunOnce from '../../../share/text-generation/run-once' + +type Props = { + appId: string + className?: string + isWorkflow?: boolean + appData: AppData | null +} + +const TextGeneration: FC = ({ + appId, + className, + isWorkflow, + appData, +}) => { + const { t } = useTranslation() + const media = useBreakpoints() + const isPC = media === MediaType.pc + + const [inputs, doSetInputs] = useState>({}) + const inputsRef = useRef>(inputs) + const setInputs = useCallback((newInputs: Record) => { + doSetInputs(newInputs) + inputsRef.current = newInputs + }, []) + + const updateAppInfo = useWebAppStore(s => s.updateAppInfo) + const { data: tryAppParams } = useGetTryAppParams(appId) + + const updateAppParams = useWebAppStore(s => s.updateAppParams) + const appParams = useWebAppStore(s => s.appParams) + const [siteInfo, setSiteInfo] = useState(null) + const [promptConfig, setPromptConfig] = useState(null) + const [customConfig, setCustomConfig] = useState | null>(null) + const [moreLikeThisConfig, setMoreLikeThisConfig] = useState(null) + const [textToSpeechConfig, setTextToSpeechConfig] = useState(null) + const [controlSend, setControlSend] = useState(0) + const [visionConfig, setVisionConfig] = useState({ + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], + }) + const [completionFiles, setCompletionFiles] = useState([]) + const [isShowResultPanel, { setTrue: doShowResultPanel, setFalse: hideResultPanel }] = useBoolean(false) + const showResultPanel = () => { + // fix: useClickAway hideResSidebar will close sidebar + setTimeout(() => { + doShowResultPanel() + }, 0) + } + + const handleSend = () => { + setControlSend(Date.now()) + showResultPanel() + } + + const [resultExisted, setResultExisted] = useState(false) + + useEffect(() => { + if (!appData) + return + updateAppInfo(appData) + }, 
[appData, updateAppInfo]) + + useEffect(() => { + if (!tryAppParams) + return + updateAppParams(tryAppParams) + }, [tryAppParams, updateAppParams]) + + useEffect(() => { + (async () => { + if (!appData || !appParams) + return + const { site: siteInfo, custom_config } = appData + setSiteInfo(siteInfo as SiteInfo) + setCustomConfig(custom_config) + + const { user_input_form, more_like_this, file_upload, text_to_speech } = appParams + setVisionConfig({ + // legacy of image upload compatible + ...file_upload, + transfer_methods: file_upload?.allowed_file_upload_methods || file_upload?.allowed_upload_methods, + // legacy of image upload compatible + image_file_size_limit: appParams?.system_parameters.image_file_size_limit, + fileUploadConfig: appParams?.system_parameters, + // eslint-disable-next-line ts/no-explicit-any + } as any) + const prompt_variables = userInputsFormToPromptVariables(user_input_form) + setPromptConfig({ + prompt_template: '', // placeholder for future + prompt_variables, + } as PromptConfig) + setMoreLikeThisConfig(more_like_this) + setTextToSpeechConfig(text_to_speech) + })() + }, [appData, appParams]) + + const [isCompleted, setIsCompleted] = useState(false) + const handleCompleted = useCallback(() => { + setIsCompleted(true) + }, []) + const [isHideTryNotice, { + setTrue: hideTryNotice, + }] = useBoolean(false) + + const renderRes = (task?: Task) => ( + setResultExisted(true)} + /> + ) + + const renderResWrap = ( +
+
+ {isCompleted && !isHideTryNotice && ( + + )} + {renderRes()} +
+
+ ) + + if (!siteInfo || !promptConfig) { + return ( +
+ +
+ ) + } + + return ( +
+ {/* Left */} +
+ {/* Header */} +
+
+ +
{siteInfo.title}
+
+ {siteInfo.description && ( +
{siteInfo.description}
+ )} +
+ {/* form */} +
+ +
+
+ + {/* Result */} +
+ {!isPC && ( +
{ + if (isShowResultPanel) + hideResultPanel() + else + showResultPanel() + }} + > +
+
+ )} + {renderResWrap} +
+
+ ) +} + +export default React.memo(TextGeneration) diff --git a/web/app/components/explore/try-app/index.tsx b/web/app/components/explore/try-app/index.tsx new file mode 100644 index 0000000000..b2e2b72140 --- /dev/null +++ b/web/app/components/explore/try-app/index.tsx @@ -0,0 +1,74 @@ +/* eslint-disable style/multiline-ternary */ +'use client' +import type { FC } from 'react' +import { RiCloseLine } from '@remixicon/react' +import * as React from 'react' +import { useState } from 'react' +import Loading from '@/app/components/base/loading' +import Modal from '@/app/components/base/modal/index' +import { useGetTryAppInfo } from '@/service/use-try-app' +import Button from '../../base/button' +import App from './app' +import AppInfo from './app-info' +import Preview from './preview' +import Tab, { TypeEnum } from './tab' + +type Props = { + appId: string + category?: string + onClose: () => void + onCreate: () => void +} + +const TryApp: FC = ({ + appId, + category, + onClose, + onCreate, +}) => { + const [type, setType] = useState(TypeEnum.TRY) + const { data: appDetail, isLoading } = useGetTryAppInfo(appId) + + return ( + + {isLoading ? ( +
+ +
+ ) : ( +
+
+ + +
+ {/* Main content */} +
+ {type === TypeEnum.TRY ? : } + +
+
+ )} +
+ ) +} +export default React.memo(TryApp) diff --git a/web/app/components/explore/try-app/preview/basic-app-preview.tsx b/web/app/components/explore/try-app/preview/basic-app-preview.tsx new file mode 100644 index 0000000000..6954546b2e --- /dev/null +++ b/web/app/components/explore/try-app/preview/basic-app-preview.tsx @@ -0,0 +1,367 @@ +/* eslint-disable ts/no-explicit-any */ +'use client' +import type { FC } from 'react' +import type { Features as FeaturesData, FileUpload } from '@/app/components/base/features/types' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import type { ModelConfig } from '@/models/debug' +import type { ModelConfig as BackendModelConfig, PromptVariable } from '@/types/app' +import { noop } from 'es-toolkit/function' +import { clone } from 'es-toolkit/object' +import * as React from 'react' +import { useMemo, useState } from 'react' +import Config from '@/app/components/app/configuration/config' +import Debug from '@/app/components/app/configuration/debug' +import { FeaturesProvider } from '@/app/components/base/features' +import Loading from '@/app/components/base/loading' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' +import ConfigContext from '@/context/debug-configuration' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { PromptMode } from '@/models/debug' +import { useAllToolProviders } from '@/service/use-tools' +import { useGetTryAppDataSets, useGetTryAppInfo } from '@/service/use-try-app' +import { ModelModeType, Resolution, TransferMethod } from '@/types/app' +import { correctModelProvider, correctToolProvider } from '@/utils' +import { userInputsFormToPromptVariables } from '@/utils/model-config' +import { basePath } from '@/utils/var' +import { useTextGenerationCurrentProviderAndModelAndModelList } from '../../../header/account-setting/model-provider-page/hooks' + +type Props = { + appId: string +} + +const defaultModelConfig = { + provider: 'langgenius/openai/openai', + model_id: 'gpt-3.5-turbo', + mode: ModelModeType.unset, + configs: { + prompt_template: '', + prompt_variables: [] as PromptVariable[], + }, + more_like_this: null, + opening_statement: '', + suggested_questions: [], + sensitive_word_avoidance: null, + speech_to_text: null, + text_to_speech: null, + file_upload: null, + suggested_questions_after_answer: null, + retriever_resource: null, + annotation_reply: null, + dataSets: [], + agentConfig: DEFAULT_AGENT_SETTING, +} +const BasicAppPreview: FC = ({ + appId, +}) => { + const media = useBreakpoints() + const isMobile = media === MediaType.mobile + + const { data: appDetail, isLoading: isLoadingAppDetail } = useGetTryAppInfo(appId) + const { data: collectionListFromServer, isLoading: isLoadingToolProviders } = useAllToolProviders() + const collectionList = collectionListFromServer?.map((item) => { + return { + ...item, + icon: basePath && typeof item.icon == 'string' && !item.icon.includes(basePath) ? 
`${basePath}${item.icon}` : item.icon, + } + }) + const datasetIds = (() => { + if (isLoadingAppDetail) + return [] + const modelConfig = appDetail?.model_config + if (!modelConfig) + return [] + let datasets: any = null + + if (modelConfig.agent_mode?.tools?.find(({ dataset }: any) => dataset?.enabled)) + datasets = modelConfig.agent_mode?.tools.filter(({ dataset }: any) => dataset?.enabled) + // new dataset struct + else if (modelConfig.dataset_configs.datasets?.datasets?.length > 0) + datasets = modelConfig.dataset_configs?.datasets?.datasets + + if (datasets?.length && datasets?.length > 0) + return datasets.map(({ dataset }: any) => dataset.id) + + return [] + })() + const { data: dataSetData, isLoading: isLoadingDatasets } = useGetTryAppDataSets(appId, datasetIds) + const dataSets = dataSetData?.data || [] + const isLoading = isLoadingAppDetail || isLoadingDatasets || isLoadingToolProviders + + const modelConfig: ModelConfig = ((modelConfig?: BackendModelConfig) => { + if (isLoading || !modelConfig) + return defaultModelConfig + + const model = modelConfig.model + + const newModelConfig = { + provider: correctModelProvider(model.provider), + model_id: model.name, + mode: model.mode, + configs: { + prompt_template: modelConfig.pre_prompt || '', + prompt_variables: userInputsFormToPromptVariables( + [ + ...(modelConfig.user_input_form as any), + ...( + modelConfig.external_data_tools?.length + ? modelConfig.external_data_tools.map((item) => { + return { + external_data_tool: { + variable: item.variable as string, + label: item.label as string, + enabled: item.enabled, + type: item.type as string, + config: item.config, + required: true, + icon: item.icon, + icon_background: item.icon_background, + }, + } + }) + : [] + ), + ], + modelConfig.dataset_query_variable, + ), + }, + more_like_this: modelConfig.more_like_this, + opening_statement: modelConfig.opening_statement, + suggested_questions: modelConfig.suggested_questions, + sensitive_word_avoidance: modelConfig.sensitive_word_avoidance, + speech_to_text: modelConfig.speech_to_text, + text_to_speech: modelConfig.text_to_speech, + file_upload: modelConfig.file_upload, + suggested_questions_after_answer: modelConfig.suggested_questions_after_answer, + retriever_resource: modelConfig.retriever_resource, + annotation_reply: modelConfig.annotation_reply, + external_data_tools: modelConfig.external_data_tools, + dataSets, + agentConfig: appDetail?.mode === 'agent-chat' + // eslint-disable-next-line style/multiline-ternary + ? ({ + max_iteration: DEFAULT_AGENT_SETTING.max_iteration, + ...modelConfig.agent_mode, + // remove dataset + enabled: true, // modelConfig.agent_mode?.enabled is not correct. old app: the value of app with dataset's is always true + tools: modelConfig.agent_mode?.tools.filter((tool: any) => { + return !tool.dataset + }).map((tool: any) => { + const toolInCollectionList = collectionList?.find(c => tool.provider_id === c.id) + return { + ...tool, + isDeleted: appDetail?.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name), + notAuthor: toolInCollectionList?.is_team_authorization === false, + ...(tool.provider_type === 'builtin' + ? 
{ + provider_id: correctToolProvider(tool.provider_name, !!toolInCollectionList), + provider_name: correctToolProvider(tool.provider_name, !!toolInCollectionList), + } + : {}), + } + }), + }) : DEFAULT_AGENT_SETTING, + } + return (newModelConfig as any) + })(appDetail?.model_config) + const mode = appDetail?.mode + // const isChatApp = ['chat', 'advanced-chat', 'agent-chat'].includes(mode!) + + // chat configuration + const promptMode = modelConfig?.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple + const isAdvancedMode = promptMode === PromptMode.advanced + const isAgent = mode === 'agent-chat' + const chatPromptConfig = isAdvancedMode ? (modelConfig?.chat_prompt_config || clone(DEFAULT_CHAT_PROMPT_CONFIG)) : undefined + const suggestedQuestions = modelConfig?.suggested_questions || [] + const moreLikeThisConfig = modelConfig?.more_like_this || { enabled: false } + const suggestedQuestionsAfterAnswerConfig = modelConfig?.suggested_questions_after_answer || { enabled: false } + const speechToTextConfig = modelConfig?.speech_to_text || { enabled: false } + const textToSpeechConfig = modelConfig?.text_to_speech || { enabled: false, voice: '', language: '' } + const citationConfig = modelConfig?.retriever_resource || { enabled: false } + const annotationConfig = modelConfig?.annotation_reply || { + id: '', + enabled: false, + score_threshold: ANNOTATION_DEFAULT.score_threshold, + embedding_model: { + embedding_provider_name: '', + embedding_model_name: '', + }, + } + const moderationConfig = modelConfig?.sensitive_word_avoidance || { enabled: false } + // completion configuration + const completionPromptConfig = modelConfig?.completion_prompt_config || clone(DEFAULT_COMPLETION_PROMPT_CONFIG) as any + + // prompt & model config + const inputs = {} + const query = '' + const completionParams = useState({}) + + const { + currentModel: currModel, + } = useTextGenerationCurrentProviderAndModelAndModelList( + { + provider: modelConfig.provider, + model: modelConfig.model_id, + }, + ) + + const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision) + const isShowDocumentConfig = !!currModel?.features?.includes(ModelFeatureEnum.document) + const isShowAudioConfig = !!currModel?.features?.includes(ModelFeatureEnum.audio) + const isAllowVideoUpload = !!currModel?.features?.includes(ModelFeatureEnum.video) + const visionConfig = { + enabled: false, + number_limits: 2, + detail: Resolution.low, + transfer_methods: [TransferMethod.local_file], + } + + const featuresData: FeaturesData = useMemo(() => { + return { + moreLikeThis: modelConfig.more_like_this || { enabled: false }, + opening: { + enabled: !!modelConfig.opening_statement, + opening_statement: modelConfig.opening_statement || '', + suggested_questions: modelConfig.suggested_questions || [], + }, + moderation: modelConfig.sensitive_word_avoidance || { enabled: false }, + speech2text: modelConfig.speech_to_text || { enabled: false }, + text2speech: modelConfig.text_to_speech || { enabled: false }, + file: { + image: { + detail: modelConfig.file_upload?.image?.detail || Resolution.high, + enabled: !!modelConfig.file_upload?.image?.enabled, + number_limits: modelConfig.file_upload?.image?.number_limits || 3, + transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), + allowed_file_types: modelConfig.file_upload?.allowed_file_types || [], + 
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`), + allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, + fileUploadConfig: {}, + } as FileUpload, + suggested: modelConfig.suggested_questions_after_answer || { enabled: false }, + citation: modelConfig.retriever_resource || { enabled: false }, + annotationReply: modelConfig.annotation_reply || { enabled: false }, + } + }, [modelConfig]) + + if (isLoading) { + return ( +
+ +
+ ) + } + const value = { + readonly: true, + appId, + isAPIKeySet: true, + isTrailFinished: false, + mode, + modelModeType: '', + promptMode, + isAdvancedMode, + isAgent, + isOpenAI: false, + isFunctionCall: false, + collectionList: [], + setPromptMode: noop, + canReturnToSimpleMode: false, + setCanReturnToSimpleMode: noop, + chatPromptConfig, + completionPromptConfig, + currentAdvancedPrompt: '', + setCurrentAdvancedPrompt: noop, + conversationHistoriesRole: completionPromptConfig.conversation_histories_role, + showHistoryModal: false, + setConversationHistoriesRole: noop, + hasSetBlockStatus: true, + conversationId: '', + introduction: '', + setIntroduction: noop, + suggestedQuestions, + setSuggestedQuestions: noop, + setConversationId: noop, + controlClearChatMessage: false, + setControlClearChatMessage: noop, + prevPromptConfig: {}, + setPrevPromptConfig: noop, + moreLikeThisConfig, + setMoreLikeThisConfig: noop, + suggestedQuestionsAfterAnswerConfig, + setSuggestedQuestionsAfterAnswerConfig: noop, + speechToTextConfig, + setSpeechToTextConfig: noop, + textToSpeechConfig, + setTextToSpeechConfig: noop, + citationConfig, + setCitationConfig: noop, + annotationConfig, + setAnnotationConfig: noop, + moderationConfig, + setModerationConfig: noop, + externalDataToolsConfig: {}, + setExternalDataToolsConfig: noop, + formattingChanged: false, + setFormattingChanged: noop, + inputs, + setInputs: noop, + query, + setQuery: noop, + completionParams, + setCompletionParams: noop, + modelConfig, + setModelConfig: noop, + showSelectDataSet: noop, + dataSets, + setDataSets: noop, + datasetConfigs: [], + datasetConfigsRef: {}, + setDatasetConfigs: noop, + hasSetContextVar: true, + isShowVisionConfig, + visionConfig, + setVisionConfig: noop, + isAllowVideoUpload, + isShowDocumentConfig, + isShowAudioConfig, + rerankSettingModalOpen: false, + setRerankSettingModalOpen: noop, + } + return ( + + +
+
+
+ +
+ {!isMobile && ( +
+
+ +
+
+ )} +
+
+
+
+ ) +} +export default React.memo(BasicAppPreview) diff --git a/web/app/components/explore/try-app/preview/flow-app-preview.tsx b/web/app/components/explore/try-app/preview/flow-app-preview.tsx new file mode 100644 index 0000000000..ba64aecfba --- /dev/null +++ b/web/app/components/explore/try-app/preview/flow-app-preview.tsx @@ -0,0 +1,39 @@ +'use client' +import type { FC } from 'react' +import * as React from 'react' +import Loading from '@/app/components/base/loading' +import WorkflowPreview from '@/app/components/workflow/workflow-preview' +import { useGetTryAppFlowPreview } from '@/service/use-try-app' +import { cn } from '@/utils/classnames' + +type Props = { + appId: string + className?: string +} + +const FlowAppPreview: FC = ({ + appId, + className, +}) => { + const { data, isLoading } = useGetTryAppFlowPreview(appId) + + if (isLoading) { + return ( +
+ +
+ ) + } + if (!data) + return null + return ( +
+ +
+ ) +} +export default React.memo(FlowAppPreview) diff --git a/web/app/components/explore/try-app/preview/index.tsx b/web/app/components/explore/try-app/preview/index.tsx new file mode 100644 index 0000000000..a0c5fdc594 --- /dev/null +++ b/web/app/components/explore/try-app/preview/index.tsx @@ -0,0 +1,25 @@ +'use client' +import type { FC } from 'react' +import type { TryAppInfo } from '@/service/try-app' +import * as React from 'react' +import BasicAppPreview from './basic-app-preview' +import FlowAppPreview from './flow-app-preview' + +type Props = { + appId: string + appDetail: TryAppInfo +} + +const Preview: FC = ({ + appId, + appDetail, +}) => { + const isBasicApp = ['agent-chat', 'chat', 'completion'].includes(appDetail.mode) + + return ( +
+ {isBasicApp ? : } +
+ ) +} +export default React.memo(Preview) diff --git a/web/app/components/explore/try-app/tab.tsx b/web/app/components/explore/try-app/tab.tsx new file mode 100644 index 0000000000..75ba402204 --- /dev/null +++ b/web/app/components/explore/try-app/tab.tsx @@ -0,0 +1,37 @@ +'use client' +import type { FC } from 'react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import TabHeader from '../../base/tab-header' + +export enum TypeEnum { + TRY = 'try', + DETAIL = 'detail', +} + +type Props = { + value: TypeEnum + onChange: (value: TypeEnum) => void +} + +const Tab: FC = ({ + value, + onChange, +}) => { + const { t } = useTranslation() + const tabs = [ + { id: TypeEnum.TRY, name: t('tryApp.tabHeader.try', { ns: 'explore' }) }, + { id: TypeEnum.DETAIL, name: t('tryApp.tabHeader.detail', { ns: 'explore' }) }, + ] + return ( + void} + itemClassName="ml-0 system-md-semibold-uppercase" + itemWrapClassName="pt-2" + activeItemClassName="border-util-colors-blue-brand-blue-brand-500" + /> + ) +} +export default React.memo(Tab) diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index e16c00acd0..07dd0fca3d 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -137,7 +137,7 @@ export default function AppSelector() { diff --git a/web/app/components/header/account-setting/api-based-extension-page/empty.tsx b/web/app/components/header/account-setting/api-based-extension-page/empty.tsx index 38525993fa..d75e66f8d0 100644 --- a/web/app/components/header/account-setting/api-based-extension-page/empty.tsx +++ b/web/app/components/header/account-setting/api-based-extension-page/empty.tsx @@ -17,7 +17,7 @@ const Empty = () => {
{t('apiBasedExtension.title', { ns: 'common' })}
diff --git a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
index d3146d7baa..b04981bf3c 100644
--- a/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
+++ b/web/app/components/header/account-setting/api-based-extension-page/modal.tsx
@@ -30,7 +30,7 @@ const ApiBasedExtensionModal: FC = ({
   onSave,
 }) => {
   const { t } = useTranslation()
-  const docLink = useDocLink('https://docs.dify.ai/versions/3-0-x')
+  const docLink = useDocLink()
   const [localeData, setLocaleData] = useState(data)
   const [loading, setLoading] = useState(false)
   const { notify } = useToastContext()
@@ -102,7 +102,7 @@ const ApiBasedExtensionModal: FC = ({
{t('detailPanel.endpointsTip', { ns: 'plugin' })}
diff --git a/web/app/components/plugins/plugin-page/debug-info.tsx b/web/app/components/plugins/plugin-page/debug-info.tsx index f62f8a4134..f3eed424f4 100644 --- a/web/app/components/plugins/plugin-page/debug-info.tsx +++ b/web/app/components/plugins/plugin-page/debug-info.tsx @@ -8,8 +8,7 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Tooltip from '@/app/components/base/tooltip' -import { getDocsUrl } from '@/app/components/plugins/utils' -import { useLocale } from '@/context/i18n' +import { useDocLink } from '@/context/i18n' import { useDebugKey } from '@/service/use-plugins' import KeyValueItem from '../base/key-value-item' @@ -17,7 +16,7 @@ const i18nPrefix = 'debugInfo' const DebugInfo: FC = () => { const { t } = useTranslation() - const locale = useLocale() + const docLink = useDocLink() const { data: info, isLoading } = useDebugKey() // info.key likes 4580bdb7-b878-471c-a8a4-bfd760263a53 mask the middle part using *. @@ -34,7 +33,7 @@ const DebugInfo: FC = () => { <>
{t(`${i18nPrefix}.title`, { ns: 'plugin' })} - + {t(`${i18nPrefix}.viewDocs`, { ns: 'plugin' })} diff --git a/web/app/components/plugins/plugin-page/index.spec.tsx b/web/app/components/plugins/plugin-page/index.spec.tsx index a3ea7f7125..9b7ada2a87 100644 --- a/web/app/components/plugins/plugin-page/index.spec.tsx +++ b/web/app/components/plugins/plugin-page/index.spec.tsx @@ -24,6 +24,7 @@ vi.mock('@/hooks/use-document-title', () => ({ vi.mock('@/context/i18n', () => ({ useLocale: () => 'en-US', + useDocLink: () => (path: string) => `https://docs.example.com${path}`, })) vi.mock('@/context/global-public-context', () => ({ diff --git a/web/app/components/plugins/plugin-page/index.tsx b/web/app/components/plugins/plugin-page/index.tsx index d852e4d0b8..efb665197a 100644 --- a/web/app/components/plugins/plugin-page/index.tsx +++ b/web/app/components/plugins/plugin-page/index.tsx @@ -15,10 +15,9 @@ import Button from '@/app/components/base/button' import TabSlider from '@/app/components/base/tab-slider' import Tooltip from '@/app/components/base/tooltip' import ReferenceSettingModal from '@/app/components/plugins/reference-setting-modal' -import { getDocsUrl } from '@/app/components/plugins/utils' import { MARKETPLACE_API_PREFIX, SUPPORT_INSTALL_LOCAL_FILE_EXTENSIONS } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' -import { useLocale } from '@/context/i18n' +import { useDocLink } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' import { usePluginInstallation } from '@/hooks/use-query-params' import { fetchBundleInfoFromMarketPlace, fetchManifestFromMarketPlace } from '@/service/plugins' @@ -47,7 +46,7 @@ const PluginPage = ({ marketplace, }: PluginPageProps) => { const { t } = useTranslation() - const locale = useLocale() + const docLink = useDocLink() useDocumentTitle(t('metadata.title', { ns: 'plugin' })) // Use nuqs hook for installation state @@ -175,7 +174,7 @@ const PluginPage = ({
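For context on the `useDocLink` migration in the hunks above: judging from the spec mock (`useDocLink: () => (path: string) => ...`), the hook returns a resolver that maps a relative docs path to an absolute, locale-aware URL, so call sites no longer hard-code a versioned base URL or thread `useLocale` through `getDocsUrl`. A minimal sketch of the call pattern, assuming that signature and using a placeholder path:

import { useDocLink } from '@/context/i18n'

// Sketch only: the path below is a placeholder, not a documented route.
const DocsLink = () => {
  const docLink = useDocLink()
  return (
    <a href={docLink('/plugins/introduction')} target="_blank" rel="noopener noreferrer">
      View docs
    </a>
  )
}
export default DocsLink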