mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/quota-icon
commit 7a465c122a

@@ -79,6 +79,29 @@ jobs:
          find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
          find . -name "*.py.bak" -type f -delete

      - name: Install pnpm
        uses: pnpm/action-setup@v4
        with:
          package_json_file: web/package.json
          run_install: false

      - name: Setup Node.js
        uses: actions/setup-node@v6
        with:
          node-version: 24
          cache: pnpm
          cache-dependency-path: ./web/pnpm-lock.yaml

      - name: Install web dependencies
        run: |
          cd web
          pnpm install --frozen-lockfile

      - name: ESLint autofix
        run: |
          cd web
          pnpm lint:fix || true

      # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
      - name: mdformat
        run: |
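For illustration, a minimal sketch of what the first sed expression in the Python autofix step above does to a quoted union annotation (the sample line is hypothetical):

```bash
# The sed rewrites string-quoted `"X" | None` annotations into `Optional["X"]`;
# the second expression in the workflow does the same for single-quoted names.
echo 'def parent(self) -> "Node" | None: ...' \
  | sed -E 's/"([^"]+)" \| None/Optional["\1"]/g'
# prints: def parent(self) -> Optional["Node"]: ...
```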
@@ -125,7 +125,7 @@ jobs:
      - name: Web type check
        if: steps.changed-files.outputs.any_changed == 'true'
        working-directory: ./web
        run: pnpm run type-check:tsgo
        run: pnpm run type-check

      - name: Web dead code check
        if: steps.changed-files.outputs.any_changed == 'true'
@@ -27,7 +27,9 @@ ignore_imports =
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
    core.workflow.nodes.loop.loop_node -> core.workflow.graph_events

    core.workflow.nodes.node_factory -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
    core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory

    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
    core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
@@ -57,6 +59,252 @@ ignore_imports =
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
type = forbidden
source_modules =
    core.workflow
forbidden_modules =
    configs
    controllers
    extensions
    models
    services
    tasks
    core.agent
    core.app
    core.base
    core.callback_handler
    core.datasource
    core.db
    core.entities
    core.errors
    core.extension
    core.external_data_tool
    core.file
    core.helper
    core.hosting_configuration
    core.indexing_runner
    core.llm_generator
    core.logging
    core.mcp
    core.memory
    core.model_manager
    core.moderation
    core.ops
    core.plugin
    core.prompt
    core.provider_manager
    core.rag
    core.repositories
    core.schemas
    core.tools
    core.trigger
    core.variables
ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.layers.observability -> configs
|
||||
core.workflow.graph_engine.layers.observability -> extensions.otel.runtime
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.ops_trace_manager
|
||||
core.workflow.graph_engine.worker_management.worker_pool -> configs
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> configs
|
||||
core.workflow.nodes.document_extractor.node -> core.file.file_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.entities -> configs
|
||||
core.workflow.nodes.http_request.executor -> configs
|
||||
core.workflow.nodes.http_request.executor -> core.file.file_manager
|
||||
core.workflow.nodes.http_request.node -> configs
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.file.models
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
core.workflow.workflow_entry -> configs
|
||||
core.workflow.workflow_entry -> models.workflow
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.graph_engine.layers.persistence -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.start.entities -> core.app.app_config.entities
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.node_events.node -> core.file
|
||||
core.workflow.nodes.agent.agent_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file
|
||||
core.workflow.nodes.datasource.datasource_node -> core.file.enums
|
||||
core.workflow.nodes.document_extractor.node -> core.file
|
||||
core.workflow.nodes.http_request.executor -> core.file.enums
|
||||
core.workflow.nodes.http_request.node -> core.file
|
||||
core.workflow.nodes.http_request.node -> core.file.file_manager
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.file.models
|
||||
core.workflow.nodes.list_operator.node -> core.file
|
||||
core.workflow.nodes.llm.file_saver -> core.file
|
||||
core.workflow.nodes.llm.llm_utils -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.file
|
||||
core.workflow.nodes.llm.node -> core.file.file_manager
|
||||
core.workflow.nodes.llm.node -> core.file.models
|
||||
core.workflow.nodes.loop.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.file
|
||||
core.workflow.nodes.protocols -> core.file
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.file.models
|
||||
core.workflow.nodes.tool.tool_node -> core.file
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.tool.tool_node -> models
|
||||
core.workflow.nodes.trigger_webhook.node -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file
|
||||
core.workflow.runtime.variable_pool -> core.file.file_manager
|
||||
core.workflow.system_variable -> core.file.models
|
||||
core.workflow.utils.condition.processor -> core.file
|
||||
core.workflow.utils.condition.processor -> core.file.file_manager
|
||||
core.workflow.workflow_entry -> core.file.models
|
||||
core.workflow.workflow_type_encoder -> core.file.models
|
||||
core.workflow.nodes.agent.agent_node -> models.model
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_node_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
|
||||
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
|
||||
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.node -> core.helper.code_executor
|
||||
core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.graph_engine.layers.persistence -> core.ops.entities.trace_entity
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.errors
|
||||
core.workflow.conversation_variable_updater -> core.variables
|
||||
core.workflow.graph_engine.entities.commands -> core.variables.variables
|
||||
core.workflow.nodes.agent.agent_node -> core.variables.segments
|
||||
core.workflow.nodes.answer.answer_node -> core.variables
|
||||
core.workflow.nodes.code.code_node -> core.variables.segments
|
||||
core.workflow.nodes.code.code_node -> core.variables.types
|
||||
core.workflow.nodes.code.entities -> core.variables.types
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
|
||||
core.workflow.nodes.document_extractor.node -> core.variables
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.variables.segments
|
||||
core.workflow.nodes.list_operator.node -> core.variables
|
||||
core.workflow.nodes.list_operator.node -> core.variables.segments
|
||||
core.workflow.nodes.llm.node -> core.variables
|
||||
core.workflow.nodes.loop.loop_node -> core.variables
|
||||
core.workflow.nodes.parameter_extractor.entities -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.exc -> core.variables.types
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.variables.types
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.segments
|
||||
core.workflow.nodes.tool.tool_node -> core.variables.variables
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.types
|
||||
core.workflow.nodes.trigger_webhook.node -> core.variables.variables
|
||||
core.workflow.nodes.variable_aggregator.entities -> core.variables.types
|
||||
core.workflow.nodes.variable_aggregator.variable_aggregator_node -> core.variables.segments
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.consts
|
||||
core.workflow.nodes.variable_assigner.common.helpers -> core.variables.types
|
||||
core.workflow.nodes.variable_assigner.v1.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.helpers -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables
|
||||
core.workflow.nodes.variable_assigner.v2.node -> core.variables.consts
|
||||
core.workflow.runtime.graph_runtime_state_protocol -> core.variables.segments
|
||||
core.workflow.runtime.read_only_wrappers -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables
|
||||
core.workflow.runtime.variable_pool -> core.variables.consts
|
||||
core.workflow.runtime.variable_pool -> core.variables.segments
|
||||
core.workflow.runtime.variable_pool -> core.variables.variables
|
||||
core.workflow.utils.condition.processor -> core.variables
|
||||
core.workflow.utils.condition.processor -> core.variables.segments
|
||||
core.workflow.variable_loader -> core.variables
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
core.workflow.nodes.tool.tool_node -> services

[importlinter:contract:rsc]
name = RSC
type = layers
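These contracts are enforced by import-linter; a minimal sketch of checking them locally (assuming the tool is installed with the api dev dependencies and picks up the contract file from its default config locations):

```bash
cd api
uv run lint-imports
```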
api/README.md

@@ -1,6 +1,6 @@

# Dify Backend API

## Usage
## Setup and Run

> [!IMPORTANT]
>

@@ -8,48 +8,77 @@

> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.

1. Start the docker-compose stack
`uv` and `pnpm` are required to run the setup and development commands below.

The backend requires some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
### Using scripts (recommended)

The scripts resolve paths relative to their location, so you can run them from anywhere.

1. Run setup (copies env files and installs dependencies).

```bash
cd ../docker
cp middleware.env.example middleware.env
# change the profile to mysql if you are not using postgres, change the profile to another vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
cd ../api
./dev/setup
```

1. Copy `.env.example` to `.env`
1. Review `api/.env`, `web/.env.local`, and `docker/middleware.env` values (see the `SECRET_KEY` note below).

```cli
cp .env.example .env
1. Start middleware (PostgreSQL/Redis/Weaviate).

```bash
./dev/start-docker-compose
```

> [!IMPORTANT]
>
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
1. Start backend (runs migrations first).

1. Generate a `SECRET_KEY` in the `.env` file.

bash for Linux

```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```bash
./dev/start-api
```

bash for Mac
1. Start Dify [web](../web) service.

```bash for Mac
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
SECRET_KEY=${secret_key}" .env
```bash
./dev/start-web
```

1. Create environment.
1. Set up your application by visiting `http://localhost:3000`.

Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, you need to add the uv package manager, if you don't have it already.
1. Optional: start the worker service (async tasks, runs from `api`).

```bash
./dev/start-worker
```

1. Optional: start Celery Beat (scheduled tasks).

```bash
./dev/start-beat
```

### Manual commands

<details>
<summary>Show manual setup and run steps</summary>

These commands assume you start from the repository root.

1. Start the docker-compose stack.

The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.

```bash
cp docker/middleware.env.example docker/middleware.env
# Use mysql or another vector database profile if you are not using postgres/weaviate.
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
```

1. Copy env files.

```bash
cp api/.env.example api/.env
cp web/.env.example web/.env.local
```

1. Install UV if needed.

```bash
pip install uv

@@ -57,60 +86,96 @@

brew install uv
```

1. Install dependencies
1. Install API dependencies.

```bash
uv sync --dev
cd api
uv sync --group dev
```

1. Run migrate

Before the first launch, migrate the database to the latest version.
1. Install web dependencies.

```bash
cd web
pnpm install
cd ..
```

1. Start backend (runs migrations first, in a new terminal).

```bash
cd api
uv run flask db upgrade
```

1. Start backend

```bash
uv run flask run --host 0.0.0.0 --port=5001 --debug
```

1. Start Dify [web](../web) service.
1. Start Dify [web](../web) service (in a new terminal).

1. Setup your application by visiting `http://localhost:3000`.
```bash
cd web
pnpm dev:inspect
```

1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
1. Set up your application by visiting `http://localhost:3000`.

```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start the worker service (async tasks, in a new terminal).

Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```

```bash
uv run celery -A app.celery beat
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

```bash
cd api
uv run celery -A app.celery beat
```

</details>

### Environment notes

> [!IMPORTANT]
>
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
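For example, a minimal sketch (which env files actually read this key is an assumption; check the `.env.example` files for the exact entries):

```bash
# Hypothetical value -- replace with your own top-level domain.
COOKIE_DOMAIN=example.com
```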

- Generate a `SECRET_KEY` in the `.env` file.

bash for Linux

```bash
sed -i "/^SECRET_KEY=/c\\SECRET_KEY=$(openssl rand -base64 42)" .env
```

bash for Mac

```bash
secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\
SECRET_KEY=${secret_key}" .env
```

## Testing

1. Install dependencies for both the backend and the test environment

```bash
uv sync --dev
cd api
uv sync --group dev
```

1. Run the tests locally with the mocked system environment variables defined in the `tool.pytest_env` section of `pyproject.toml`; see [CLAUDE.md](../CLAUDE.md) for more details.

```bash
cd api
uv run pytest                          # Run all tests
uv run pytest tests/unit_tests/        # Unit tests only
uv run pytest tests/integration_tests/ # Integration tests

# Code quality
../dev/reformat                # Run all formatters and linters
uv run ruff check --fix ./     # Fix linting issues
uv run ruff format ./          # Format code
uv run basedpyright .          # Type checking
./dev/reformat                 # Run all formatters and linters
uv run ruff check --fix ./     # Fix linting issues
uv run ruff format ./          # Format code
uv run basedpyright .          # Type checking
```

@@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp):
        ext_commands,
        ext_compress,
        ext_database,
        ext_fastopenapi,
        ext_forward_refs,
        ext_hosting_provider,
        ext_import_modules,

@@ -128,6 +129,7 @@ def initialize_extensions(app: DifyApp):
        ext_proxy_fix,
        ext_blueprints,
        ext_commands,
        ext_fastopenapi,
        ext_otel,
        ext_request_logging,
        ext_session_factory,

api/commands.py

@@ -950,6 +950,346 @@ def clean_workflow_runs(
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"archive-workflow-runs",
|
||||
help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
|
||||
)
|
||||
@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
|
||||
@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--to-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created at or after this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Archive runs created before this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
|
||||
@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
|
||||
@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
|
||||
def archive_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
delete_after_archive: bool,
|
||||
):
|
||||
"""
|
||||
Archive workflow runs for paid plan tenants older than the specified days.
|
||||
|
||||
This command archives the following tables to storage:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
|
||||
The workflow_runs and workflow_app_logs tables are preserved for UI listing.
|
||||
"""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
run_started_at = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting workflow run archiving at {run_started_at.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
|
||||
return
|
||||
if from_days_ago <= to_days_ago:
|
||||
click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
|
||||
return
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if start_from and end_before and start_from >= end_before:
|
||||
click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
|
||||
return
|
||||
if workers < 1:
|
||||
click.echo(click.style("workers must be at least 1.", fg="red"))
|
||||
return
|
||||
|
||||
archiver = WorkflowRunArchiver(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
workers=workers,
|
||||
tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
delete_after_archive=delete_after_archive,
|
||||
)
|
||||
summary = archiver.run()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
|
||||
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
|
||||
f"time={summary.total_elapsed_time:.2f}s",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
|
||||
run_finished_at = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = run_finished_at - run_started_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow run archiving completed. start={run_started_at.isoformat()} "
|
||||
f"end={run_finished_at.isoformat()} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
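A hedged usage sketch for the new archive command (assuming it is registered on the Flask CLI alongside the other maintenance commands and run from the `api` directory; the tenant IDs are placeholders):

```bash
# Preview archiving runs older than 90 days for two tenants.
uv run flask archive-workflow-runs \
  --tenant-ids "tenant-a,tenant-b" \
  --before-days 90 \
  --batch-size 100 \
  --workers 2 \
  --dry-run
# Re-run without --dry-run (optionally with --delete-after-archive) to actually archive.
```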
|
||||
|
||||
|
||||
@click.command(
|
||||
"restore-workflow-runs",
|
||||
help="Restore archived workflow runs from S3-compatible storage.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to restore.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
|
||||
def restore_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
workers: int,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Restore an archived workflow run from storage to the database.
|
||||
|
||||
This restores the following tables:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
"""
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch restore.")
|
||||
if workers < 1:
|
||||
raise click.BadParameter("workers must be at least 1")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
|
||||
if run_id:
|
||||
results = [restorer.restore_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = restorer.restore_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
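Similarly, a sketch of invoking the restore command (same registration assumption; the run ID, tenant ID, and dates are placeholders):

```bash
# Restore a single archived run.
uv run flask restore-workflow-runs --run-id 00000000-0000-0000-0000-000000000000

# Or restore a window of runs for one tenant.
uv run flask restore-workflow-runs \
  --tenant-ids "tenant-a" \
  --start-from 2024-01-01 \
  --end-before 2024-02-01 \
  --workers 2
```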
|
||||
|
||||
|
||||
@click.command(
|
||||
"delete-archived-workflow-runs",
|
||||
help="Delete archived workflow runs from the database.",
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-ids",
|
||||
required=False,
|
||||
help="Tenant IDs (comma-separated).",
|
||||
)
|
||||
@click.option("--run-id", required=False, help="Workflow run ID to delete.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
|
||||
def delete_archived_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Delete archived workflow runs from the database.
|
||||
"""
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
parsed_tenant_ids = None
|
||||
if tenant_ids:
|
||||
parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
if not parsed_tenant_ids:
|
||||
raise click.BadParameter("tenant-ids must not be empty")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch delete.")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting delete of {target_desc} at {start_time.isoformat()}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
|
||||
if run_id:
|
||||
results = [deleter.delete_by_run_id(run_id)]
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = deleter.delete_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
|
||||
f"workflow run {result.run_id} (tenant={result.tenant_id})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to delete workflow run {result.run_id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
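And a sketch for the deletion command (same assumptions; preview with `--dry-run` before deleting):

```bash
uv run flask delete-archived-workflow-runs \
  --tenant-ids "tenant-a" \
  --start-from 2024-01-01 \
  --end-before 2024-02-01 \
  --limit 100 \
  --dry-run
```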
|
||||
|
||||
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
|
||||
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
|
||||
def clear_orphaned_file_records(force: bool):

@@ -965,6 +965,16 @@ class MailConfig(BaseSettings):
        default=None,
    )

    ENABLE_TRIAL_APP: bool = Field(
        description="Enable trial app",
        default=False,
    )

    ENABLE_EXPLORE_BANNER: bool = Field(
        description="Enable explore banner",
        default=False,
    )
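A minimal sketch of turning the new flags on through the environment (assumption: they are read from `api/.env` like the other `BaseSettings` fields; both default to `false`):

```bash
ENABLE_TRIAL_APP=true
ENABLE_EXPLORE_BANNER=true
```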


class RagEtlConfig(BaseSettings):
    """
@@ -3,6 +3,7 @@ Flask App Context - Flask implementation of AppContext interface.
"""

import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final

@@ -118,6 +119,7 @@ class FlaskExecutionContext:
        self._context_vars = context_vars
        self._user = user
        self._flask_app = flask_app
        self._local = threading.local()

    @property
    def app_context(self) -> FlaskAppContext:

@@ -136,47 +138,39 @@ class FlaskExecutionContext:

    def __enter__(self) -> "FlaskExecutionContext":
        """Enter the Flask execution context."""
        # Restore context variables
        # Restore non-Flask context variables to avoid leaking Flask tokens across threads
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        self._cm = self._app_context.enter()
        self._cm.__enter__()
        cm = self._app_context.enter()
        self._local.cm = cm
        cm.__enter__()

        # Restore user in new app context
        if saved_user is not None:
            g._login_user = saved_user
        if self._user is not None:
            g._login_user = self._user

        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the Flask execution context."""
        if hasattr(self, "_cm"):
            self._cm.__exit__(*args)
        cm = getattr(self._local, "cm", None)
        if cm is not None:
            cm.__exit__(*args)

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """Enter Flask execution context as context manager."""
        # Restore context variables
        # Restore non-Flask context variables to avoid leaking Flask tokens across threads
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        with self._flask_app.app_context():
            # Restore user in new app context
            if saved_user is not None:
                g._login_user = saved_user
            if self._user is not None:
                g._login_user = self._user
            yield

@@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (

# Import explore controllers
from .explore import (
    banner,
    installed_app,
    parameter,
    recommended_app,
    saved_message,
    trial,
)

# Import tag controllers

@@ -145,6 +147,7 @@ __all__ = [
    "apikey",
    "app",
    "audio",
    "banner",
    "billing",
    "bp",
    "completion",

@@ -198,6 +201,7 @@ __all__ = [
    "statistic",
    "tags",
    "tool_providers",
    "trial",
    "trigger_providers",
    "version",
    "website",
@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
|
|||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
|
@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
|
|||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
|
|
@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
|
|||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreAppPayload.__name__,
|
||||
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
InsertExploreBannerPayload.__name__,
|
||||
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
|
|
@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
|
|||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
|
@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
|
|||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
|
@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
|
|||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
content = {
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
}
|
||||
|
||||
banner = ExporleBanner(
|
||||
content=content,
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
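A hedged sketch of calling the new banner endpoint (the `/console/api` prefix and the bearer-token admin auth are assumptions based on the surrounding console API; all payload values are placeholders):

```bash
curl -X POST "http://localhost:5001/console/api/admin/insert-explore-banner" \
  -H "Authorization: Bearer $ADMIN_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
        "category": "featured",
        "title": "Example banner",
        "description": "Placeholder description",
        "img-src": "https://example.com/banner.png",
        "language": "en-US",
        "link": "https://example.com",
        "sort": 1
      }'
```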
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
|
|

@@ -82,13 +82,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
class DraftWorkflowNotExist(BaseHTTPException):
    error_code = "draft_workflow_not_exist"
    description = "Draft workflow need to be initialized."
    code = 400
    code = 404


class DraftWorkflowNotSync(BaseHTTPException):
    error_code = "draft_workflow_not_sync"
    description = "Workflow graph might have been modified, please refresh and resubmit."
    code = 400
    code = 409


class TracingConfigNotExist(BaseHTTPException):

@@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
    error_code = "rate_limit_error"
    description = "Rate Limit Error"
    code = 429


class NeedAddIdsError(BaseHTTPException):
    error_code = "need_add_ids"
    description = "Need to add ids."
    code = 400

@ -470,7 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -508,7 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
|
|
@ -999,6 +999,7 @@ class DraftWorkflowTriggerRunApi(Resource):
|
|||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
|
|
@ -1147,6 +1148,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
|||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,10 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from fields.workflow_app_log_fields import (
|
||||
build_workflow_app_log_pagination_model,
|
||||
build_workflow_archived_log_pagination_model,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
|
@ -61,6 +64,7 @@ console_ns.schema_model(
|
|||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
|
|
@ -99,3 +103,33 @@ class WorkflowAppLogApi(Resource):
|
|||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-archived-logs")
|
||||
class WorkflowArchivedLogApi(Resource):
|
||||
@console_ns.doc("get_workflow_archived_logs")
|
||||
@console_ns.doc(description="Get workflow archived execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_archived_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow archived logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
|
|
@ -19,14 +22,17 @@ from fields.workflow_run_fields import (
|
|||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
|
@ -93,6 +99,15 @@ workflow_run_node_execution_list_model = console_ns.model(
|
|||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_export_fields = console_ns.model(
|
||||
"WorkflowRunExport",
|
||||
{
|
||||
"status": fields.String(description="Export status: success/failed"),
|
||||
"presigned_url": fields.String(description="Pre-signed URL for download", required=False),
|
||||
"presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False),
|
||||
},
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
|
|
@ -181,6 +196,56 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/export")
|
||||
class WorkflowRunExportApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_export_url")
|
||||
@console_ns.doc(description="Generate a download URL for an archived workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Export URL generated", workflow_run_export_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: str):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
run_id_str = str(run_id)
|
||||
|
||||
run_created_at = db.session.scalar(
|
||||
select(WorkflowArchiveLog.run_created_at)
|
||||
.where(
|
||||
WorkflowArchiveLog.tenant_id == tenant_id,
|
||||
WorkflowArchiveLog.app_id == app_id,
|
||||
WorkflowArchiveLog.workflow_run_id == run_id_str,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if not run_created_at:
|
||||
return {"code": "archive_log_not_found", "message": "workflow run archive not found"}, 404
|
||||
|
||||
prefix = (
|
||||
f"{tenant_id}/app_id={app_id}/year={run_created_at.strftime('%Y')}/"
|
||||
f"month={run_created_at.strftime('%m')}/workflow_run_id={run_id_str}"
|
||||
)
|
||||
archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
|
||||
|
||||
try:
|
||||
archive_storage = get_archive_storage()
|
||||
except ArchiveStorageNotConfiguredError as e:
|
||||
return {"code": "archive_storage_not_configured", "message": str(e)}, 500
|
||||
|
||||
presigned_url = archive_storage.generate_presigned_url(
|
||||
archive_key,
|
||||
expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS,
|
||||
)
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS)
|
||||
return {
|
||||
"status": "success",
|
||||
"presigned_url": presigned_url,
|
||||
"presigned_url_expires_at": expires_at.isoformat(),
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
|
|||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
|
|
@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
|||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
app_id = kwargs.get("app_id")
|
||||
app_id = str(app_id)
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model_with_trial(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
if mode is not None:
|
||||
if isinstance(mode, list):
|
||||
modes = mode
|
||||
else:
|
||||
modes = [mode]
|
||||
|
||||
if app_mode not in modes:
|
||||
mode_values = {m.value for m in modes}
|
||||
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
|
||||
|
||||
kwargs["app_model"] = app_model
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import explore_banner_enabled
|
||||
from extensions.ext_database import db
|
||||
from models.model import ExporleBanner
|
||||
|
||||
|
||||
class BannerApi(Resource):
|
||||
"""Resource for banner list."""
|
||||
|
||||
@explore_banner_enabled
|
||||
def get(self):
|
||||
"""Get banner list."""
|
||||
language = request.args.get("language", "en-US")
|
||||
|
||||
# Build base query for enabled banners
|
||||
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
|
||||
|
||||
# Try to get banners in the requested language
|
||||
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
|
||||
|
||||
# Fallback to en-US if no banners found and language is not en-US
|
||||
if not banners and language != "en-US":
|
||||
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
|
||||
# Convert banners to serializable format
|
||||
result = []
|
||||
for banner in banners:
|
||||
banner_data = {
|
||||
"id": banner.id,
|
||||
"content": banner.content, # Already parsed as JSON by SQLAlchemy
|
||||
"link": banner.link,
|
||||
"sort": banner.sort,
|
||||
"status": banner.status,
|
||||
"created_at": banner.created_at.isoformat() if banner.created_at else None,
|
||||
}
|
||||
result.append(banner_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
api.add_resource(BannerApi, "/explore/banners")
|
||||
|
|
@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
|
|||
error_code = "access_denied"
|
||||
description = "App access denied."
|
||||
code = 403
|
||||
|
||||
|
||||
class TrialAppNotAllowed(BaseHTTPException):
|
||||
"""*403* `Trial App Not Allowed`
|
||||
|
||||
Raise if the user has reached the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_not_allowed"
|
||||
code = 403
|
||||
description = "the app is not allowed to be trial."
|
||||
|
||||
|
||||
class TrialAppLimitExceeded(BaseHTTPException):
|
||||
"""*403* `Trial App Limit Exceeded`
|
||||
|
||||
Raise if the user has exceeded the trial app limit.
|
||||
"""
|
||||
|
||||
error_code = "trial_app_limit_exceeded"
|
||||
code = 403
|
||||
description = "The user has exceeded the trial app limit."
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ recommended_app_fields = {
|
|||
"category": fields.String,
|
||||
"position": fields.Integer,
|
||||
"is_listed": fields.Boolean,
|
||||
"can_trial": fields.Boolean,
|
||||
}
|
||||
|
||||
recommended_app_list_fields = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,512 @@
|
|||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
NeedAddIdsError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model_with_trial
|
||||
from controllers.console.explore.error import (
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotWorkflowAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields_with_site
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.account import TenantStatus
|
||||
from models.model import AppMode, Site
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_service import AppService
|
||||
from services.audio_service import AudioService
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import (
|
||||
MessageNotExistsError,
|
||||
SuggestedQuestionsAfterAnswerDisabledError,
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
def post(self, trial_app, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = trial_app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialMessageSuggestedQuestionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def get(self, trial_app, message_id):
|
||||
app_model = trial_app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation not found")
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
class TrialChatAudioApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
# Get IDs before they might be detached from session
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TrialSitApi(Resource):
|
||||
"""Resource for trial app sites."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
"""
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
assert app_model.tenant
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return SiteResponse.model_validate(site).model_dump(mode="json")
|
||||
|
||||
|
||||
class TrialAppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
"""Retrieve app parameters."""
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
class AppWorkflowApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
@marshal_with(workflow_fields)
|
||||
def get(self, app_model):
|
||||
"""Get workflow detail"""
|
||||
if not app_model.workflow_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
|
||||
class DatasetListApi(Resource):
|
||||
@trial_feature_enable
|
||||
@get_app_model_with_trial
|
||||
def get(self, app_model):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
|
||||
tenant_id = app_model.tenant_id
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
|
||||
else:
|
||||
raise NeedAddIdsError()
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
|
||||
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
|
||||
|
||||
api.add_resource(
|
||||
TrialMessageSuggestedQuestionApi,
|
||||
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="trial_app_suggested_question",
|
||||
)
|
||||
|
||||
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
|
||||
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
|
||||
|
||||
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
|
||||
|
||||
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
|
||||
|
||||
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
|
||||
|
||||
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
|
||||
|
||||
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
|
||||
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
|
||||
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
|
||||
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")
|
||||
|
|
@ -2,14 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import InstalledApp
|
||||
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
|
||||
|
||||
if trial_app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
app = trial_app.app
|
||||
|
||||
if app is None:
|
||||
raise TrialAppNotAllowed()
|
||||
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
if account_trial_app_record.count >= trial_app.trial_limit:
|
||||
raise TrialAppLimitExceeded()
|
||||
|
||||
return view(app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_trial_app:
|
||||
abort(403, "Trial app feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.enable_explore_banner:
|
||||
abort(403, "Explore banner feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class InstalledAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
|
|
@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
|
|||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
||||
|
||||
class TrialAppResource(Resource):
|
||||
# must be reversed if there are multiple decorators
|
||||
|
||||
method_decorators = [
|
||||
trial_app_required,
|
||||
account_initialization_required,
|
||||
login_required,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import console_ns
|
||||
|
|
@ -39,5 +40,21 @@ class SystemFeatureApi(Resource):
|
|||
),
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for dashboard initialization.
|
||||
|
||||
Authentication would create circular dependency (can't login without dashboard loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
# NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated`
|
||||
# without a try-catch. However, due to the implementation of user loader (the `load_user_from_request`
|
||||
# in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will
|
||||
# raise `Unauthorized` exception if authentication token is not provided.
|
||||
try:
|
||||
is_authenticated = current_user.is_authenticated
|
||||
except Unauthorized:
|
||||
is_authenticated = False
|
||||
return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump()
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from . import console_ns
|
||||
from controllers.fastopenapi import console_router
|
||||
|
||||
|
||||
@console_ns.route("/ping")
|
||||
class PingApi(Resource):
|
||||
@console_ns.doc("health_check")
|
||||
@console_ns.doc(description="Health check endpoint for connection testing")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
|
||||
)
|
||||
def get(self):
|
||||
"""Health check endpoint for connection testing"""
|
||||
return {"result": "pong"}
|
||||
class PingResponse(BaseModel):
|
||||
result: str = Field(description="Health check result", examples=["pong"])
|
||||
|
||||
|
||||
@console_router.get(
|
||||
"/ping",
|
||||
response_model=PingResponse,
|
||||
tags=["console"],
|
||||
)
|
||||
def ping() -> PingResponse:
|
||||
"""Health check endpoint for connection testing."""
|
||||
return PingResponse(result="pong")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from fastopenapi.routers import FlaskRouter
|
||||
|
||||
console_router = FlaskRouter()
|
||||
|
|
@ -261,17 +261,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
|
|
@ -280,6 +269,18 @@ class DocumentAddByFileApi(DatasetApiResource):
|
|||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
|
@ -370,17 +371,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
|
|
@ -389,6 +379,18 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
|||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
|
|
|||
|
|
@ -17,5 +17,15 @@ class SystemFeatureApi(Resource):
|
|||
|
||||
Returns:
|
||||
dict: System feature configuration object
|
||||
|
||||
This endpoint is akin to the `SystemFeatureApi` endpoint in api/controllers/console/feature.py,
|
||||
except it is intended for use by the web app, instead of the console dashboard.
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design, as it provides system features
|
||||
data required for webapp initialization.
|
||||
|
||||
Authentication would create circular dependency (can't authenticate without webapp loading).
|
||||
|
||||
Only non-sensitive configuration data should be returned by this endpoint.
|
||||
"""
|
||||
return FeatureService.get_system_features().model_dump()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
|
|
@ -18,7 +20,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import email
|
||||
from libs.helper import EmailStr
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
|
|
@ -30,10 +32,35 @@ from services.app_service import AppService
|
|||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class EmailCodeLoginSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class EmailCodeLoginVerifyPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
register_schema_models(web_ns, LoginPayload, EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload)
|
||||
|
||||
|
||||
@web_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for web app email/password login."""
|
||||
|
||||
@web_ns.expect(web_ns.models[LoginPayload.__name__])
|
||||
@setup_required
|
||||
@only_edition_enterprise
|
||||
@web_ns.doc("web_app_login")
|
||||
|
|
@ -50,15 +77,10 @@ class LoginApi(Resource):
|
|||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=valid_password, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
|
|
@ -145,6 +167,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@only_edition_enterprise
|
||||
@web_ns.doc("send_email_code_login")
|
||||
@web_ns.doc(description="Send email verification code for login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginSendPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code sent successfully",
|
||||
|
|
@ -153,19 +176,14 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
}
|
||||
)
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
if payload.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||
account = WebAppAuthService.get_user_through_email(payload.email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
|
|
@ -179,6 +197,7 @@ class EmailCodeLoginApi(Resource):
|
|||
@only_edition_enterprise
|
||||
@web_ns.doc("verify_email_code_login")
|
||||
@web_ns.doc(description="Verify email code and complete login")
|
||||
@web_ns.expect(web_ns.models[EmailCodeLoginVerifyPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Email code verified and login successful",
|
||||
|
|
@ -189,17 +208,11 @@ class EmailCodeLoginApi(Resource):
|
|||
)
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = EmailCodeLoginVerifyPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = args["email"].lower()
|
||||
user_email = payload.email.lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
|
|
@ -210,10 +223,10 @@ class EmailCodeLoginApi(Resource):
|
|||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
if token_data["code"] != payload.code:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import reqparse
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
|
|
@ -27,19 +29,22 @@ from models.model import App, AppMode, EndUser
|
|||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@web_ns.doc("Run Workflow")
|
||||
@web_ns.doc(description="Execute a workflow with provided inputs and files.")
|
||||
@web_ns.doc(
|
||||
params={
|
||||
"inputs": {"description": "Input variables for the workflow", "type": "object", "required": True},
|
||||
"files": {"description": "Files to be processed by the workflow", "type": "array", "required": False},
|
||||
}
|
||||
)
|
||||
@web_ns.expect(web_ns.models[WorkflowRunPayload.__name__])
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Success",
|
||||
|
|
@ -58,12 +63,8 @@ class WorkflowRunApi(WebApiResource):
|
|||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(web_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -13,6 +15,9 @@ from sqlalchemy.orm import Session, sessionmaker
|
|||
import contexts
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
|
|
@ -304,7 +309,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
|
|
@ -320,7 +325,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
if args.inputs is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
|
|
@ -338,7 +343,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
|
|||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -40,6 +42,9 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
|
|||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from controllers.console.app.workflow import LoopNodeRunPayload
|
||||
|
||||
SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -381,7 +386,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
|
|
@ -397,7 +402,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
if args.inputs is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
|
|
@ -413,7 +418,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args.inputs or {}),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
|
|
@ -53,7 +54,6 @@ from core.workflow.graph_events import (
|
|||
)
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
@ -166,18 +166,22 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
# Determine which type of single node execution and get graph/variable_pool
|
||||
if single_iteration_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
elif single_loop_run:
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
|
||||
|
|
@ -314,44 +318,6 @@ class WorkflowBasedAppRunner:
|
|||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
return self._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .node_factory import DifyNodeFactory
|
||||
|
||||
__all__ = ["DifyNodeFactory"]
|
||||
|
|
@ -15,6 +15,7 @@ from core.workflow.nodes.base.node import Node
|
|||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
|
|
@ -23,8 +24,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
|||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from libs.typing import is_str, is_str_dict
|
||||
|
||||
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
|
@ -154,7 +154,7 @@ class IrisConnectionPool:
|
|||
# Add to cache to skip future checks
|
||||
self._schemas_initialized.add(schema)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
logger.exception("Failed to ensure schema %s exists", schema)
|
||||
raise
|
||||
|
|
@ -177,6 +177,9 @@ class IrisConnectionPool:
|
|||
class IrisVector(BaseVector):
|
||||
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
|
||||
|
||||
# Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
|
||||
_FULL_TEXT_FALLBACK_SCORE = 0.5
|
||||
|
||||
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
|
||||
super().__init__(collection_name)
|
||||
self.config = config
|
||||
|
|
@@ -272,41 +275,131 @@ class IrisVector(BaseVector):
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search documents by full-text using iFind index or fallback to LIKE search."""
"""Search documents by full-text using iFind index with BM25 relevance scoring.

When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
function is automatically created with naming: {schema}.{table_name}_{index}Rank

Args:
query: Search query string
**kwargs: Optional parameters including top_k, document_ids_filter

Returns:
List of Document objects with relevance scores in metadata["score"]
"""
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")

with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
# Use iFind full-text search with index
# Use iFind full-text search with auto-generated Rank function
text_index_name = f"idx_{self.table_name}_text"
# IRIS removes underscores from function names
table_no_underscore = self.table_name.replace("_", "")
index_no_underscore = text_index_name.replace("_", "")
rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"

# Build WHERE clause with document ID filter if provided
where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
# First param for Rank function, second for FIND
params = [query, query]

if document_ids_filter:
# Add document ID filter
placeholders = ",".join("?" * len(document_ids_filter))
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)

sql = f"""
SELECT TOP {top_k} id, text, meta
SELECT TOP {top_k}
id,
text,
meta,
{rank_function}(%ID, ?) AS score
FROM {self.schema}.{self.table_name}
WHERE %ID %FIND search_index({text_index_name}, ?)
{where_clause}
ORDER BY score DESC
"""
cursor.execute(sql, (query,))

logger.debug(
"iFind search: query='%s', index='%s', rank='%s'",
query,
text_index_name,
rank_function,
)

try:
cursor.execute(sql, params)
except Exception:  # pylint: disable=broad-exception-caught
# Fallback to query without Rank function if it fails
logger.warning(
"Rank function '%s' failed, using fixed score",
rank_function,
exc_info=True,
)
sql_fallback = f"""
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
{where_clause}
"""
# Skip first param (for Rank function)
cursor.execute(sql_fallback, params[1:])
else:
# Fallback to LIKE search (inefficient for large datasets)
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
# Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
from libs.helper import (  # pylint: disable=import-outside-toplevel
escape_like_pattern,
)

escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"

# Build WHERE clause with document ID filter if provided
where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
params = [query_pattern]

if document_ids_filter:
placeholders = ",".join("?" * len(document_ids_filter))
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)

sql = f"""
SELECT TOP {top_k} id, text, meta
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
WHERE text LIKE ? ESCAPE '\\'
{where_clause}
ORDER BY LENGTH(text) ASC
"""
cursor.execute(sql, (query_pattern,))

logger.debug(
"LIKE fallback (TEXT_INDEX disabled): query='%s'",
query_pattern,
)
cursor.execute(sql, params)

docs = []
for row in cursor.fetchall():
if len(row) >= 3:
metadata = json.loads(row[2]) if row[2] else {}
docs.append(Document(page_content=row[1], metadata=metadata))
# Expecting 4 columns: id, text, meta, score
if len(row) >= 4:
text_content = row[1]
meta_str = row[2]
score_value = row[3]

metadata = json.loads(meta_str) if meta_str else {}
# Add score to metadata for hybrid search compatibility
score = float(score_value) if score_value is not None else 0.0
metadata["score"] = score

docs.append(Document(page_content=text_content, metadata=metadata))

logger.info(
"Full-text search completed: query='%s', results=%d/%d",
query,
len(docs),
top_k,
)

if not docs:
logger.info("Full-text search for '%s' returned no results", query)
logger.warning("Full-text search for '%s' returned no results", query)

return docs
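Note on the Rank-function naming used above: IRIS strips underscores when it generates the %iFind Rank function, so the fully qualified name is derived as in this standalone sketch. It is an illustration only; the schema and table names below are invented.

import hashlib  # unused here, shown only to keep the sketch obviously standalone

def ifind_rank_function(schema: str, table_name: str) -> str:
    # Mirrors the naming rule in the diff: {schema}.{table}_{index}Rank,
    # with underscores removed from the table and index parts.
    text_index_name = f"idx_{table_name}_text"
    table_part = table_name.replace("_", "")
    index_part = text_index_name.replace("_", "")
    return f"{schema}.{table_part}_{index_part}Rank"

assert ifind_rank_function("SQLUser", "embedding_abc") == "SQLUser.embeddingabc_idxembeddingabctextRank"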
@@ -370,7 +463,11 @@ class IrisVector(BaseVector):
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
logger.info("Creating text index: %s with language: %s", text_index_name, language)
logger.info(
"Creating text index: %s with language: %s",
text_index_name,
language,
)
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
@@ -130,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str

class JsonMessage(BaseModel):
json_object: dict
json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")

class BlobMessage(BaseModel):
@@ -144,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk")

class FileMessage(BaseModel):
pass
file_marker: str = Field(default="file_marker")

@model_validator(mode="before")
@classmethod
def validate_file_message(cls, values):
if isinstance(values, dict) and "file_marker" not in values:
raise ValueError("Invalid FileMessage: missing file_marker")
return values

class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
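The FileMessage change above replaces the empty `pass` body with a sentinel field plus a "before" validator, so arbitrary dicts no longer coerce silently into a FileMessage. A minimal standalone sketch of that pattern (not the real Dify class):

from pydantic import BaseModel, Field, ValidationError, model_validator

class FileMessageSketch(BaseModel):
    file_marker: str = Field(default="file_marker")

    @model_validator(mode="before")
    @classmethod
    def validate_file_message(cls, values):
        # Reject payloads that do not carry the sentinel key.
        if isinstance(values, dict) and "file_marker" not in values:
            raise ValueError("Invalid FileMessage: missing file_marker")
        return values

FileMessageSketch(file_marker="file_marker")            # accepted
try:
    FileMessageSketch.model_validate({"text": "hello"})  # rejected: not a file payload
except ValidationError:
    pass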
@@ -234,10 +241,22 @@ class ToolInvokeMessage(BaseModel):

@field_validator("message", mode="before")
@classmethod
def decode_blob_message(cls, v):
def decode_blob_message(cls, v, info: ValidationInfo):
# Handle blob decoding
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])

# Force correct message type based on type field
# Only wrap dict types to avoid wrapping already parsed Pydantic model objects
if info.data and isinstance(info.data, dict) and isinstance(v, dict):
msg_type = info.data.get("type")
if msg_type == cls.MessageType.JSON:
if "json_object" not in v:
v = {"json_object": v}
elif msg_type == cls.MessageType.FILE:
v = {"file_marker": "file_marker"}

return v

@field_serializer("message")
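For reference, the blob-decoding branch above behaves roughly like this standalone sketch; a payload that is not valid base64 is left untouched thanks to contextlib.suppress:

import base64
import contextlib

def decode_blob(value: dict) -> dict:
    # Decode an incoming {"blob": "<base64>"} payload in place; ignore bad input.
    if isinstance(value, dict) and "blob" in value:
        with contextlib.suppress(Exception):
            value["blob"] = base64.b64decode(value["blob"])
    return value

assert decode_blob({"blob": base64.b64encode(b"hello").decode()})["blob"] == b"hello"
assert decode_blob({"blob": "not base64!!"})["blob"] == "not base64!!"  # left as-is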
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
|
|
@ -36,6 +37,8 @@ from extensions.ext_database import db
|
|||
from models.enums import CreatorUserRole
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolEngine:
|
||||
"""
|
||||
|
|
@ -123,25 +126,31 @@ class ToolEngine:
|
|||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
logger.error(e, exc_info=True)
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
||||
logger.error(e, exc_info=True)
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolParameterValidationError as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
logger.error(e, exc_info=True)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
logger.error(e, exc_info=True)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
logger.error(e, exc_info=True)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
|
|||
)
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
|
@ -28,21 +27,6 @@ from models.workflow import Workflow
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_resolve_user_from_request() -> Account | EndUser | None:
|
||||
"""
|
||||
Try to resolve user from Flask request context.
|
||||
|
||||
Returns None if not in a request context or if user is not available.
|
||||
"""
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
# Use _get_current_object() to dereference the proxy
|
||||
user = getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
# Check if we got a valid user object
|
||||
if user is not None and hasattr(user, "id"):
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
"""
|
||||
Workflow tool.
|
||||
|
|
@ -223,12 +207,6 @@ class WorkflowTool(Tool):
|
|||
Returns:
|
||||
Account | EndUser | None: The resolved user object, or None if resolution fails.
|
||||
"""
|
||||
# Try to resolve user from request context first
|
||||
user = _try_resolve_user_from_request()
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
# Fall back to database resolution
|
||||
return self._resolve_user_from_database(user_id=user_id)
|
||||
|
||||
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ Execution Context - Abstracted context management for workflow execution.
|
|||
"""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
|
|
@ -88,6 +89,7 @@ class ExecutionContext:
|
|||
self._app_context = app_context
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
self._local = threading.local()
|
||||
|
||||
@property
|
||||
def app_context(self) -> AppContext | None:
|
||||
|
|
@@ -125,14 +127,16 @@ class ExecutionContext:

def __enter__(self) -> "ExecutionContext":
"""Enter the execution context."""
self._cm = self.enter()
self._cm.__enter__()
cm = self.enter()
self._local.cm = cm
cm.__enter__()
return self

def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
cm = getattr(self._local, "cm", None)
if cm is not None:
cm.__exit__(*args)


class NullAppContext(AppContext):
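The __enter__/__exit__ change above moves the context-manager handle from an instance attribute to threading.local. A toy illustration (not the Dify classes) of why that matters when several threads share one context object, each exiting only the manager it opened:

import threading
from contextlib import contextmanager

class SharedContext:
    def __init__(self):
        self._local = threading.local()

    @contextmanager
    def enter(self):
        yield

    def __enter__(self):
        cm = self.enter()
        self._local.cm = cm          # per-thread handle, not an instance attribute
        cm.__enter__()
        return self

    def __exit__(self, *args):
        cm = getattr(self._local, "cm", None)
        if cm is not None:
            cm.__exit__(*args)

ctx = SharedContext()

def worker():
    with ctx:
        pass

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()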
|
|
|||
|
|
@ -11,7 +11,6 @@ import time
|
|||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, final
|
||||
from uuid import uuid4
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
|
|
@ -113,7 +112,7 @@ class Worker(threading.Thread):
|
|||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=str(uuid4()),
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
|
|
|
|||
|
|
@@ -235,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
params[key] = value_param.get("value", "") if value_param is not None else None
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")

variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))

params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
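A hedged sketch of the new variable-type branch above, with a plain dict standing in for the real VariablePool; the selector names are invented:

from typing import Any

class VariableNotFound(Exception):
    pass

def resolve_param(value_param: dict[str, Any] | None, pool: dict[tuple[str, ...], Any]) -> Any:
    # Variable-typed parameters are looked up through the pool by selector.
    if value_param and value_param.get("type", "") == "variable":
        selector = value_param.get("value")
        if not selector:
            raise ValueError("Variable selector is missing for a variable-type parameter.")
        if tuple(selector) not in pool:
            raise VariableNotFound(str(selector))
        return pool[tuple(selector)]
    # Constant parameters keep the previous behaviour.
    return value_param.get("value", "") if value_param is not None else None

pool = {("start", "query"): "hello"}
assert resolve_param({"type": "variable", "value": ["start", "query"]}, pool) == "hello"
assert resolve_param({"type": "constant", "value": 42}, pool) == 42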
|
@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
|
|
@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
|
|||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
|
|
@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
|
|
|
|||
|
|
@ -469,12 +469,8 @@ class Node(Generic[NodeDataT]):
|
|||
import core.workflow.nodes as _nodes_pkg
|
||||
|
||||
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||
if _modname in {
|
||||
"core.workflow.nodes.node_factory",
|
||||
"core.workflow.nodes.node_mapping",
|
||||
}:
|
||||
# Avoid importing modules that depend on the registry to prevent circular imports.
|
||||
if _modname == "core.workflow.nodes.node_mapping":
|
||||
continue
|
||||
importlib.import_module(_modname)
|
||||
|
||||
|
|
|
|||
|
|
@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
|||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -588,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
|
|
|
|||
|
|
@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
# Create GraphInitParams from node attributes
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
|
@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
message.message.metadata = dict_metadata
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Any
|
|||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.file.models import File
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
|
|
@ -19,7 +20,6 @@ from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
|||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,21 @@
from enum import StrEnum


class HostedTrialProvider(StrEnum):
"""
Enum representing hosted model provider names for trial access.
"""

OPENAI = "langgenius/openai/openai"
ANTHROPIC = "langgenius/anthropic/anthropic"
GEMINI = "langgenius/gemini/google"
X = "langgenius/x/x"
DEEPSEEK = "langgenius/deepseek/deepseek"
TONGYI = "langgenius/tongyi/tongyi"

@property
def config_key(self) -> str:
"""Return the config key used in dify_config (e.g., HOSTED_{config_key}_PAID_ENABLED)."""
if self == HostedTrialProvider.X:
return "XAI"
return self.name
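Usage sketch for the new enum, following the HOSTED_{config_key}_PAID_ENABLED pattern quoted in its docstring; the enum is copied here in trimmed form purely for illustration:

from enum import StrEnum

class HostedTrialProviderSketch(StrEnum):
    OPENAI = "langgenius/openai/openai"
    X = "langgenius/x/x"

    @property
    def config_key(self) -> str:
        # X maps to the XAI config key; all other members reuse their own name.
        if self == HostedTrialProviderSketch.X:
            return "XAI"
        return self.name

assert HostedTrialProviderSketch.OPENAI.config_key == "OPENAI"
assert HostedTrialProviderSketch.X.config_key == "XAI"
assert f"HOSTED_{HostedTrialProviderSketch.X.config_key}_PAID_ENABLED" == "HOSTED_XAI_PAID_ENABLED"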
|
@ -4,6 +4,7 @@ from dify_app import DifyApp
|
|||
def init_app(app: DifyApp):
|
||||
from commands import (
|
||||
add_qdrant_index,
|
||||
archive_workflow_runs,
|
||||
clean_expired_messages,
|
||||
clean_workflow_runs,
|
||||
cleanup_orphaned_draft_variables,
|
||||
|
|
@ -11,6 +12,7 @@ def init_app(app: DifyApp):
|
|||
clear_orphaned_file_records,
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
|
|
@ -24,6 +26,7 @@ def init_app(app: DifyApp):
|
|||
reset_email,
|
||||
reset_encrypt_key_pair,
|
||||
reset_password,
|
||||
restore_workflow_runs,
|
||||
setup_datasource_oauth_client,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_system_trigger_oauth_client,
|
||||
|
|
@ -58,6 +61,9 @@ def init_app(app: DifyApp):
|
|||
setup_datasource_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
install_rag_pipeline_plugins,
|
||||
archive_workflow_runs,
|
||||
delete_archived_workflow_runs,
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.fastopenapi import console_router
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
|
||||
openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
|
||||
|
||||
router = FlaskRouter(
|
||||
app=app,
|
||||
docs_url=docs_url,
|
||||
redoc_url=redoc_url,
|
||||
openapi_url=openapi_url,
|
||||
openapi_version="3.0.0",
|
||||
title="Dify API (FastOpenAPI PoC)",
|
||||
version="1.0",
|
||||
description="FastOpenAPI proof of concept for Dify API",
|
||||
)
|
||||
|
||||
# Ensure route decorators are evaluated.
|
||||
import controllers.console.ping as ping_module
|
||||
|
||||
_ = ping_module
|
||||
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(AUTHENTICATED_HEADERS),
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
)
|
||||
app.extensions["fastopenapi"] = router
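The docs toggle in this new extension reduces to passing None URLs when Swagger UI is disabled, so no documentation routes get mounted. A minimal sketch of that pattern (the flag value below is made up):

DOCS_PREFIX = "/fastopenapi"

def docs_urls(docs_enabled: bool) -> tuple[str | None, str | None, str | None]:
    # When docs are disabled, all three URLs are None and the router mounts nothing.
    docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
    redoc_url = f"{DOCS_PREFIX}/redoc" if docs_enabled else None
    openapi_url = f"{DOCS_PREFIX}/openapi.json" if docs_enabled else None
    return docs_url, redoc_url, openapi_url

assert docs_urls(False) == (None, None, None)
assert docs_urls(True)[0] == "/fastopenapi/docs"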
|
||||
|
|
@ -2,7 +2,12 @@ from flask_restx import Namespace, fields
|
|||
|
||||
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
|
||||
from fields.member_fields import build_simple_account_model, simple_account_fields
|
||||
from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields
|
||||
from fields.workflow_run_fields import (
|
||||
build_workflow_run_for_archived_log_model,
|
||||
build_workflow_run_for_log_model,
|
||||
workflow_run_for_archived_log_fields,
|
||||
workflow_run_for_log_fields,
|
||||
)
|
||||
from libs.helper import TimestampField
|
||||
|
||||
workflow_app_log_partial_fields = {
|
||||
|
|
@ -34,6 +39,33 @@ def build_workflow_app_log_partial_model(api_or_ns: Namespace):
|
|||
return api_or_ns.model("WorkflowAppLogPartial", copied_fields)
|
||||
|
||||
|
||||
workflow_archived_log_partial_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True),
|
||||
"trigger_metadata": fields.Raw,
|
||||
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
|
||||
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_archived_log_partial_model(api_or_ns: Namespace):
|
||||
"""Build the workflow archived log partial model for the API or Namespace."""
|
||||
workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns)
|
||||
simple_account_model = build_simple_account_model(api_or_ns)
|
||||
simple_end_user_model = build_simple_end_user_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_archived_log_partial_fields.copy()
|
||||
copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True)
|
||||
copied_fields["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
copied_fields["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields)
|
||||
|
||||
|
||||
workflow_app_log_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
|
|
@ -51,3 +83,21 @@ def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
|
|||
copied_fields = workflow_app_log_pagination_fields.copy()
|
||||
copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model))
|
||||
return api_or_ns.model("WorkflowAppLogPagination", copied_fields)
|
||||
|
||||
|
||||
workflow_archived_log_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(workflow_archived_log_partial_fields)),
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_archived_log_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the workflow archived log pagination model for the API or Namespace."""
|
||||
workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns)
|
||||
|
||||
copied_fields = workflow_archived_log_pagination_fields.copy()
|
||||
copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model))
|
||||
return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,19 @@ def build_workflow_run_for_log_model(api_or_ns: Namespace):
|
|||
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)
|
||||
|
||||
|
||||
workflow_run_for_archived_log_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"triggered_from": fields.String,
|
||||
"elapsed_time": fields.Float,
|
||||
"total_tokens": fields.Integer,
|
||||
}
|
||||
|
||||
|
||||
def build_workflow_run_for_archived_log_model(api_or_ns: Namespace):
|
||||
return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields)
|
||||
|
||||
|
||||
workflow_run_for_list_fields = {
|
||||
"id": fields.String,
|
||||
"version": fields.String,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ to S3-compatible object storage.
|
|||
|
||||
import base64
|
||||
import datetime
|
||||
import gzip
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
|
@ -39,7 +38,7 @@ class ArchiveStorage:
|
|||
"""
|
||||
S3-compatible storage client for archiving or exporting.
|
||||
|
||||
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
|
||||
This client provides methods for storing and retrieving archived data in JSONL format.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket: str):
|
||||
|
|
@ -69,7 +68,10 @@ class ArchiveStorage:
|
|||
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
|
||||
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
|
||||
region_name=dify_config.ARCHIVE_STORAGE_REGION,
|
||||
config=Config(s3={"addressing_style": "path"}),
|
||||
config=Config(
|
||||
s3={"addressing_style": "path"},
|
||||
max_pool_connections=64,
|
||||
),
|
||||
)
|
||||
|
||||
# Verify bucket accessibility
|
||||
|
|
@@ -100,12 +102,18 @@ class ArchiveStorage:
"""
checksum = hashlib.md5(data).hexdigest()
try:
self.client.put_object(
response = self.client.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentMD5=self._content_md5(data),
)
etag = response.get("ETag")
if not etag:
raise ArchiveStorageError(f"Missing ETag for '{key}'")
normalized_etag = etag.strip('"')
if normalized_etag != checksum:
raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized_etag}")
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
return checksum
except ClientError as e:
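The put_object change above verifies the returned ETag against a locally computed MD5; for single-part, non-KMS uploads the S3-style ETag is typically the quoted hex MD5 of the body. A standalone sketch of that check:

import hashlib

class ArchiveStorageError(Exception):
    pass

def verify_etag(data: bytes, etag: str | None, key: str) -> str:
    # Compare the normalised ETag with the local checksum to catch corrupted uploads.
    checksum = hashlib.md5(data).hexdigest()
    if not etag:
        raise ArchiveStorageError(f"Missing ETag for '{key}'")
    normalized = etag.strip('"')
    if normalized != checksum:
        raise ArchiveStorageError(f"ETag mismatch for '{key}': expected={checksum}, actual={normalized}")
    return checksum

payload = b'{"run_id": "demo"}\n'
assert verify_etag(payload, '"' + hashlib.md5(payload).hexdigest() + '"', "demo.jsonl") == hashlib.md5(payload).hexdigest()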
|
@@ -240,19 +248,18 @@ class ArchiveStorage:
return base64.b64encode(hashlib.md5(data).digest()).decode()

@staticmethod
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
def serialize_to_jsonl(records: list[dict[str, Any]]) -> bytes:
"""
Serialize records to gzipped JSONL format.
Serialize records to JSONL format.

Args:
records: List of dictionaries to serialize

Returns:
Gzipped JSONL bytes
JSONL bytes
"""
lines = []
for record in records:
# Convert datetime objects to ISO format strings
serialized = ArchiveStorage._serialize_record(record)
lines.append(orjson.dumps(serialized))

@@ -260,23 +267,22 @@ class ArchiveStorage:
if jsonl_content:
jsonl_content += b"\n"

return gzip.compress(jsonl_content)
return jsonl_content

@staticmethod
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
def deserialize_from_jsonl(data: bytes) -> list[dict[str, Any]]:
"""
Deserialize gzipped JSONL data to records.
Deserialize JSONL data to records.

Args:
data: Gzipped JSONL bytes
data: JSONL bytes

Returns:
List of dictionaries
"""
jsonl_content = gzip.decompress(data)
records = []

for line in jsonl_content.splitlines():
for line in data.splitlines():
if line:
records.append(orjson.loads(line))
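Round-trip sketch of the now-uncompressed JSONL helpers above, using orjson as the diff does; the join step between the two hunks and the record contents are assumptions made for illustration:

import orjson

def serialize_to_jsonl(records: list[dict]) -> bytes:
    # One JSON document per line, with a trailing newline when non-empty.
    lines = [orjson.dumps(r) for r in records]
    content = b"\n".join(lines)
    if content:
        content += b"\n"
    return content

def deserialize_from_jsonl(data: bytes) -> list[dict]:
    return [orjson.loads(line) for line in data.splitlines() if line]

records = [{"id": "1", "status": "succeeded"}, {"id": "2", "status": "failed"}]
assert deserialize_from_jsonl(serialize_to_jsonl(records)) == records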
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
"""make message annotation question not nullable
|
||||
|
||||
Revision ID: 9e6fa5cbcd80
|
||||
Revises: 03f8dcbc611e
|
||||
Create Date: 2025-11-06 16:03:54.549378
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9e6fa5cbcd80'
|
||||
down_revision = '288345cd01d1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
bind = op.get_bind()
|
||||
message_annotations = sa.table(
|
||||
"message_annotations",
|
||||
sa.column("id", sa.String),
|
||||
sa.column("message_id", sa.String),
|
||||
sa.column("question", sa.Text),
|
||||
)
|
||||
messages = sa.table(
|
||||
"messages",
|
||||
sa.column("id", sa.String),
|
||||
sa.column("query", sa.Text),
|
||||
)
|
||||
update_question_from_message = (
|
||||
sa.update(message_annotations)
|
||||
.where(
|
||||
sa.and_(
|
||||
message_annotations.c.question.is_(None),
|
||||
message_annotations.c.message_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.values(
|
||||
question=sa.select(sa.func.coalesce(messages.c.query, ""))
|
||||
.where(messages.c.id == message_annotations.c.message_id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
)
|
||||
bind.execute(update_question_from_message)
|
||||
|
||||
fill_remaining_questions = (
|
||||
sa.update(message_annotations)
|
||||
.where(message_annotations.c.question.is_(None))
|
||||
.values(question="")
|
||||
)
|
||||
bind.execute(fill_remaining_questions)
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
|
||||
batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)
|
||||
|
|
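To make the correlated UPDATE in the migration above easier to follow, here is a hedged, self-contained re-run of the same backfill against an in-memory SQLite database; the table shapes are trimmed to the columns the migration touches and the row values are invented:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
messages = sa.Table("messages", meta,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("query", sa.Text))
annotations = sa.Table("message_annotations", meta,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("message_id", sa.String),
    sa.Column("question", sa.Text))
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(messages.insert(), [{"id": "m1", "query": "why?"}])
    conn.execute(annotations.insert(), [
        {"id": "a1", "message_id": "m1", "question": None},   # backfilled from messages.query
        {"id": "a2", "message_id": None, "question": None},   # falls back to ""
    ])
    # Step 1: copy the original question from the linked message.
    conn.execute(
        sa.update(annotations)
        .where(sa.and_(annotations.c.question.is_(None), annotations.c.message_id.isnot(None)))
        .values(question=sa.select(sa.func.coalesce(messages.c.query, ""))
                .where(messages.c.id == annotations.c.message_id)
                .scalar_subquery())
    )
    # Step 2: anything still NULL becomes an empty string before NOT NULL is enforced.
    conn.execute(sa.update(annotations).where(annotations.c.question.is_(None)).values(question=""))
    rows = dict(conn.execute(sa.select(annotations.c.id, annotations.c.question)).all())

assert rows == {"a1": "why?", "a2": ""}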
@ -0,0 +1,73 @@
|
|||
"""add table explore banner and trial
|
||||
|
||||
Revision ID: f9f6d18a37f9
|
||||
Revises: 9e6fa5cbcd80
|
||||
Create Date: 2026-01-017 11:10:18.079355
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'f9f6d18a37f9'
|
||||
down_revision = '9e6fa5cbcd80'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('account_trial_app_records',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('account_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('count', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
|
||||
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
|
||||
)
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
|
||||
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
|
||||
|
||||
op.create_table('exporle_banners',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('content', sa.JSON(), nullable=False),
|
||||
sa.Column('link', sa.String(length=255), nullable=False),
|
||||
sa.Column('sort', sa.Integer(), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
|
||||
)
|
||||
op.create_table('trial_apps',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('trial_limit', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
|
||||
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
|
||||
)
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
|
||||
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
|
||||
batch_op.drop_index('trial_app_tenant_id_idx')
|
||||
batch_op.drop_index('trial_app_app_id_idx')
|
||||
|
||||
op.drop_table('trial_apps')
|
||||
op.drop_table('exporle_banners')
|
||||
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
|
||||
batch_op.drop_index('account_trial_app_record_app_id_idx')
|
||||
batch_op.drop_index('account_trial_app_record_account_id_idx')
|
||||
|
||||
op.drop_table('account_trial_app_records')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""create workflow_archive_logs
|
||||
|
||||
Revision ID: 9d77545f524e
|
||||
Revises: f9f6d18a37f9
|
||||
Create Date: 2026-01-06 17:18:56.292479
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9d77545f524e'
|
||||
down_revision = 'f9f6d18a37f9'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
if _is_pg(conn):
|
||||
op.create_table('workflow_archive_logs',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
sa.Column('log_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_by_role', sa.String(length=255), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('log_created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('log_created_from', sa.String(length=255), nullable=True),
|
||||
sa.Column('run_version', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_status', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_error', models.types.LongText(), nullable=True),
|
||||
sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
|
||||
sa.Column('run_created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('run_finished_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
|
||||
sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
|
||||
sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
|
||||
)
|
||||
else:
|
||||
op.create_table('workflow_archive_logs',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('log_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('workflow_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('workflow_run_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_by_role', sa.String(length=255), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('log_created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('log_created_from', sa.String(length=255), nullable=True),
|
||||
sa.Column('run_version', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_status', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_triggered_from', sa.String(length=255), nullable=False),
|
||||
sa.Column('run_error', models.types.LongText(), nullable=True),
|
||||
sa.Column('run_elapsed_time', sa.Float(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('run_total_tokens', sa.BigInteger(), server_default=sa.text('0'), nullable=False),
|
||||
sa.Column('run_total_steps', sa.Integer(), server_default=sa.text('0'), nullable=True),
|
||||
sa.Column('run_created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('run_finished_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('run_exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True),
|
||||
sa.Column('trigger_metadata', models.types.LongText(), nullable=True),
|
||||
sa.Column('archived_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_archive_log_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_archive_log_app_idx', ['tenant_id', 'app_id'], unique=False)
|
||||
batch_op.create_index('workflow_archive_log_run_created_at_idx', ['run_created_at'], unique=False)
|
||||
batch_op.create_index('workflow_archive_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
|
||||
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_archive_logs', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_archive_log_workflow_run_id_idx')
|
||||
batch_op.drop_index('workflow_archive_log_run_created_at_idx')
|
||||
batch_op.drop_index('workflow_archive_log_app_idx')
|
||||
|
||||
op.drop_table('workflow_archive_logs')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -35,6 +35,7 @@ from .enums import (
|
|||
WorkflowTriggerStatus,
|
||||
)
|
||||
from .model import (
|
||||
AccountTrialAppRecord,
|
||||
ApiRequest,
|
||||
ApiToken,
|
||||
App,
|
||||
|
|
@ -47,6 +48,7 @@ from .model import (
|
|||
DatasetRetrieverResource,
|
||||
DifySetup,
|
||||
EndUser,
|
||||
ExporleBanner,
|
||||
IconType,
|
||||
InstalledApp,
|
||||
Message,
|
||||
|
|
@ -62,6 +64,7 @@ from .model import (
|
|||
TagBinding,
|
||||
TenantCreditPool,
|
||||
TraceAppConfig,
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
|
|
@ -100,6 +103,7 @@ from .workflow import (
|
|||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowArchiveLog,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
|
|
@ -114,6 +118,7 @@ __all__ = [
|
|||
"Account",
|
||||
"AccountIntegrate",
|
||||
"AccountStatus",
|
||||
"AccountTrialAppRecord",
|
||||
"ApiRequest",
|
||||
"ApiToken",
|
||||
"ApiToolProvider",
|
||||
|
|
@ -150,6 +155,7 @@ __all__ = [
|
|||
"DocumentSegment",
|
||||
"Embedding",
|
||||
"EndUser",
|
||||
"ExporleBanner",
|
||||
"ExternalKnowledgeApis",
|
||||
"ExternalKnowledgeBindings",
|
||||
"IconType",
|
||||
|
|
@ -188,6 +194,7 @@ __all__ = [
|
|||
"ToolLabelBinding",
|
||||
"ToolModelInvoke",
|
||||
"TraceAppConfig",
|
||||
"TrialApp",
|
||||
"TriggerOAuthSystemClient",
|
||||
"TriggerOAuthTenantClient",
|
||||
"TriggerSubscription",
|
||||
|
|
@ -197,6 +204,7 @@ __all__ = [
|
|||
"Workflow",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowAppLogCreatedFrom",
|
||||
"WorkflowArchiveLog",
|
||||
"WorkflowNodeExecutionModel",
|
||||
"WorkflowNodeExecutionOffload",
|
||||
"WorkflowNodeExecutionTriggeredFrom",
|
||||
|
|
|
|||
|
|
@ -315,40 +315,48 @@ class App(Base):
|
|||
return None
|
||||
|
||||
|
||||
class AppModelConfig(Base):
|
||||
class AppModelConfig(TypeBase):
|
||||
__tablename__ = "app_model_configs"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id"))
|
||||
|
||||
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
provider = mapped_column(String(255), nullable=True)
|
||||
model_id = mapped_column(String(255), nullable=True)
|
||||
configs = mapped_column(sa.JSON, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=True)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = mapped_column(StringUUID, nullable=True)
|
||||
updated_at = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
configs: Mapped[Any | None] = mapped_column(sa.JSON, nullable=True, default=None)
|
||||
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
opening_statement = mapped_column(LongText)
|
||||
suggested_questions = mapped_column(LongText)
|
||||
suggested_questions_after_answer = mapped_column(LongText)
|
||||
speech_to_text = mapped_column(LongText)
|
||||
text_to_speech = mapped_column(LongText)
|
||||
more_like_this = mapped_column(LongText)
|
||||
model = mapped_column(LongText)
|
||||
user_input_form = mapped_column(LongText)
|
||||
dataset_query_variable = mapped_column(String(255))
|
||||
pre_prompt = mapped_column(LongText)
|
||||
agent_mode = mapped_column(LongText)
|
||||
sensitive_word_avoidance = mapped_column(LongText)
|
||||
retriever_resource = mapped_column(LongText)
|
||||
prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'"))
|
||||
chat_prompt_config = mapped_column(LongText)
|
||||
completion_prompt_config = mapped_column(LongText)
|
||||
dataset_configs = mapped_column(LongText)
|
||||
external_data_tools = mapped_column(LongText)
|
||||
file_upload = mapped_column(LongText)
|
||||
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=func.current_timestamp(),
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
opening_statement: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
suggested_questions: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
suggested_questions_after_answer: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
speech_to_text: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
text_to_speech: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
more_like_this: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
model: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
user_input_form: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
dataset_query_variable: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
pre_prompt: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
agent_mode: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
prompt_type: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, server_default=sa.text("'simple'"), default="simple"
|
||||
)
|
||||
chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
dataset_configs: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
external_data_tools: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
file_upload: Mapped[str | None] = mapped_column(LongText, default=None)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
|
|
@ -603,6 +611,64 @@ class InstalledApp(TypeBase):
|
|||
return tenant
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
__tablename__ = "trial_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||
sa.Index("trial_app_app_id_idx", "app_id"),
|
||||
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
__tablename__ = "account_trial_app_records"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
|
||||
|
||||
class ExporleBanner(Base):
|
||||
__tablename__ = "exporle_banners"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
|
||||
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
content = mapped_column(sa.JSON, nullable=False)
|
||||
link = mapped_column(String(255), nullable=False)
|
||||
sort = mapped_column(sa.Integer, nullable=False)
|
||||
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
|
||||
|
||||
|
||||
class OAuthProviderApp(TypeBase):
|
||||
"""
|
||||
Globally shared OAuth provider app information.
|
||||
|
|
@ -752,8 +818,8 @@ class Conversation(Base):
|
|||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
|
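Why the Conversation hunk above now passes app_id: once AppModelConfig is declared dataclass-style on TypeBase, fields without defaults become required constructor arguments. A plain-dataclass analogy (not the real ORM base), with invented field values:

from dataclasses import dataclass

@dataclass
class AppModelConfigSketch:
    app_id: str                       # required: mirrors Mapped[str] with no default
    provider: str | None = None       # optional fields keep default=None
    model_id: str | None = None

AppModelConfigSketch(app_id="app-123")   # ok
try:
    AppModelConfigSketch()               # TypeError: missing required app_id
except TypeError:
    pass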
|
@ -1423,7 +1489,7 @@ class MessageAnnotation(Base):
|
|||
app_id: Mapped[str] = mapped_column(StringUUID)
|
||||
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
|
||||
message_id: Mapped[str | None] = mapped_column(StringUUID)
|
||||
question: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
question: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
content: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
|
||||
    account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

@@ -226,8 +226,7 @@ class Workflow(Base):
        #
        # Currently, the following functions / methods would mutate the returned dict:
        #
        # - `_get_graph_and_variable_pool_of_single_iteration`.
        # - `_get_graph_and_variable_pool_of_single_loop`.
        # - `_get_graph_and_variable_pool_for_single_node_run`.
        return json.loads(self.graph) if self.graph else {}

    def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:

@@ -1163,6 +1162,69 @@ class WorkflowAppLog(TypeBase):
        }


class WorkflowArchiveLog(TypeBase):
    """
    Workflow archive log.

    Stores essential workflow run snapshot data for archived app logs.

    Field sources:
    - Shared fields (tenant/app/workflow/run ids, created_by*): from WorkflowRun for consistency.
    - log_* fields: from WorkflowAppLog when present; null if the run has no app log.
    - run_* fields: workflow run snapshot fields from WorkflowRun.
    - trigger_metadata: snapshot from WorkflowTriggerLog when present.
    """

    __tablename__ = "workflow_archive_logs"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="workflow_archive_log_pkey"),
        sa.Index("workflow_archive_log_app_idx", "tenant_id", "app_id"),
        sa.Index("workflow_archive_log_workflow_run_id_idx", "workflow_run_id"),
        sa.Index("workflow_archive_log_run_created_at_idx", "run_created_at"),
    )

    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
    )

    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_by_role: Mapped[str] = mapped_column(String(255), nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)

    log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    log_created_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    log_created_from: Mapped[str | None] = mapped_column(String(255), nullable=True)

    run_version: Mapped[str] = mapped_column(String(255), nullable=False)
    run_status: Mapped[str] = mapped_column(String(255), nullable=False)
    run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False)
    run_error: Mapped[str | None] = mapped_column(LongText, nullable=True)
    run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
    run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
    run_total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
    run_created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
    run_finished_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    run_exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)

    trigger_metadata: Mapped[str | None] = mapped_column(LongText, nullable=True)
    archived_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )

    @property
    def workflow_run_summary(self) -> dict[str, Any]:
        return {
            "id": self.workflow_run_id,
            "status": self.run_status,
            "triggered_from": self.run_triggered_from,
            "elapsed_time": self.run_elapsed_time,
            "total_tokens": self.run_total_tokens,
        }


class ConversationVariable(TypeBase):
    __tablename__ = "workflow_conversation_variables"
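
For reference, a minimal sketch of how the new workflow_run_summary property might be consumed when rendering an archived run; the session/select wiring and run_id are illustrative and mirror the repository queries added later in this diff.

from sqlalchemy import select

def load_run_summary(session, run_id: str) -> dict | None:
    # Look up the archive row for a run and expose its compact summary dict.
    archive_log = session.scalar(
        select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
    )
    return archive_log.workflow_run_summary if archive_log else None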
|
||||
|
||||
|
|
|
|||
|
|
@@ -31,7 +31,7 @@ dependencies = [
    "gunicorn~=23.0.0",
    "httpx[socks]~=0.27.0",
    "jieba==0.42.1",
    "json-repair>=0.41.1",
    "json-repair>=0.55.1",
    "jsonschema>=4.25.1",
    "langfuse~=2.51.3",
    "langsmith~=0.1.77",

@@ -93,6 +93,7 @@ dependencies = [
    "weaviate-client==4.17.0",
    "apscheduler>=3.11.0",
    "weave>=0.52.16",
    "fastopenapi[flask]>=0.7.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
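
The paired json-repair entries above are the old and new pins of a version bump (0.41.1 to 0.55.1). As a quick sanity check, a hedged usage sketch of the package's repair_json helper, which is the API the bump is expected to keep working; exact behavior may differ between versions.

from json_repair import repair_json

# Trailing commas and similar LLM-output glitches are repaired into valid JSON.
broken = '{"answer": "ok", "score": 0.9,}'
print(repair_json(broken))  # e.g. '{"answer": "ok", "score": 0.9}'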
|
||||
|
|
|
|||
|
|
@@ -8,6 +8,7 @@
  ],
  "typeCheckingMode": "strict",
  "allowedUntypedLibraries": [
    "fastopenapi",
    "flask_restx",
    "flask_login",
    "opentelemetry.instrumentation.celery",
|
||||
|
|
|
|||
|
|
@@ -16,7 +16,7 @@ from typing import Protocol
from sqlalchemy.orm import Session

from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload


class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):

@@ -209,3 +209,23 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
            The number of executions deleted
        """
        ...

    def get_offloads_by_execution_ids(
        self,
        session: Session,
        node_execution_ids: Sequence[str],
    ) -> Sequence[WorkflowNodeExecutionOffload]:
        """
        Get offload records by node execution IDs.

        This method retrieves workflow node execution offload records
        that belong to the given node execution IDs.

        Args:
            session: The database session to use
            node_execution_ids: List of node execution IDs to filter by

        Returns:
            A sequence of WorkflowNodeExecutionOffload instances
        """
        ...
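
A minimal sketch of calling the new protocol method through the SQLAlchemy-backed implementation; the factory and session wiring below follows the pattern used by the retention services later in this diff, and the execution IDs are placeholders.

from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from repositories.factory import DifyAPIRepositoryFactory

session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)

with session_maker() as session:
    # Returns an empty sequence when no IDs are passed, per the implementation further down.
    offloads = repo.get_offloads_by_execution_ids(session, ["node-exec-id-1", "node-exec-id-2"])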
|
||||
|
|
|
|||
|
|
@@ -45,7 +45,7 @@ from core.workflow.enums import WorkflowType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
    AverageInteractionStats,
|
||||
|
|
@ -270,6 +270,58 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_archived_run_ids(
|
||||
self,
|
||||
session: Session,
|
||||
run_ids: Sequence[str],
|
||||
) -> set[str]:
|
||||
"""
|
||||
Fetch workflow run IDs that already have archive log records.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_archived_logs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
tenant_ids: Sequence[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int,
|
||||
) -> Sequence[WorkflowArchiveLog]:
|
||||
"""
|
||||
Fetch archived workflow logs by time range for restore.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_archived_log_by_run_id(
|
||||
self,
|
||||
run_id: str,
|
||||
) -> WorkflowArchiveLog | None:
|
||||
"""
|
||||
Fetch a workflow archive log by workflow run ID.
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_archive_log_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Delete archive log by workflow run ID.
|
||||
|
||||
Used after restoring a workflow run to remove the archive log record,
|
||||
allowing the run to be archived again if needed.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
run_id: Workflow run ID
|
||||
|
||||
Returns:
|
||||
Number of records deleted (0 or 1)
|
||||
"""
|
||||
...
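
A hedged sketch of how a restore path could combine the two methods above: look up the archive record by run ID, then drop it inside a transaction so the run can be archived again. The repo and session_maker objects are assumed to be wired as in the earlier examples.

archive_log = repo.get_archived_log_by_run_id(run_id)
if archive_log is not None:
    with session_maker() as session, session.begin():
        # Returns 0 or 1 per the protocol contract.
        deleted = repo.delete_archive_log_by_run_id(session, run_id)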
|
||||
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
|
|
@ -282,6 +334,61 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_pause_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowPause]:
|
||||
"""
|
||||
Fetch workflow pause records by workflow run ID.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_pause_reason_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
pause_ids: Sequence[str],
|
||||
) -> Sequence[WorkflowPauseReason]:
|
||||
"""
|
||||
Fetch workflow pause reason records by pause IDs.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_app_logs_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowAppLog]:
|
||||
"""
|
||||
Fetch workflow app logs by workflow run ID.
|
||||
"""
|
||||
...
|
||||
|
||||
def create_archive_logs(
|
||||
self,
|
||||
session: Session,
|
||||
run: WorkflowRun,
|
||||
app_logs: Sequence[WorkflowAppLog],
|
||||
trigger_metadata: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
Create archive log records for a workflow run.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_archived_runs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
tenant_ids: Sequence[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Return workflow runs that already have archive logs, for cleanup of `workflow_runs`.
|
||||
"""
|
||||
...
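
To make the intended cleanup flow concrete, a minimal sketch of pairing get_archived_runs_by_time_range with delete_runs_with_related; start_date, end_date, and the two deletion callbacks are placeholders standing in for the helpers the retention services define further down in this diff.

with session_maker() as session:
    runs = repo.get_archived_runs_by_time_range(
        session=session,
        tenant_ids=None,          # None means all tenants
        start_date=start_date,
        end_date=end_date,
        limit=100,
    )

deleted_counts = repo.delete_runs_with_related(
    runs,
    delete_node_executions=delete_node_executions,  # callback: (session, runs) -> (int, int)
    delete_trigger_logs=delete_trigger_logs,        # callback: (session, run_ids) -> int
)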
|
||||
|
||||
def count_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
|
|
|
|||
|
|
@ -351,3 +351,27 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||
)
|
||||
|
||||
return int(node_executions_count), int(offloads_count)
|
||||
|
||||
@staticmethod
|
||||
def get_by_run(
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Fetch node executions for a run using workflow_run_id.
|
||||
"""
|
||||
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def get_offloads_by_execution_ids(
|
||||
self,
|
||||
session: Session,
|
||||
node_execution_ids: Sequence[str],
|
||||
) -> Sequence[WorkflowNodeExecutionOffload]:
|
||||
if not node_execution_ids:
|
||||
return []
|
||||
|
||||
stmt = select(WorkflowNodeExecutionOffload).where(
|
||||
WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids)
|
||||
)
|
||||
return list(session.scalars(stmt))
|
||||
|
|
|
|||
|
|
@ -40,14 +40,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
|||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import (
|
||||
WorkflowAppLog,
|
||||
WorkflowPauseReason,
|
||||
WorkflowRun,
|
||||
)
|
||||
from models.workflow import (
|
||||
WorkflowPause as WorkflowPauseModel,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.types import (
|
||||
|
|
@ -369,6 +362,53 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def get_archived_run_ids(
|
||||
self,
|
||||
session: Session,
|
||||
run_ids: Sequence[str],
|
||||
) -> set[str]:
|
||||
if not run_ids:
|
||||
return set()
|
||||
|
||||
stmt = select(WorkflowArchiveLog.workflow_run_id).where(WorkflowArchiveLog.workflow_run_id.in_(run_ids))
|
||||
return set(session.scalars(stmt).all())
|
||||
|
||||
def get_archived_log_by_run_id(
|
||||
self,
|
||||
run_id: str,
|
||||
) -> WorkflowArchiveLog | None:
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id).limit(1)
|
||||
return session.scalar(stmt)
|
||||
|
||||
def delete_archive_log_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> int:
|
||||
stmt = delete(WorkflowArchiveLog).where(WorkflowArchiveLog.workflow_run_id == run_id)
|
||||
result = session.execute(stmt)
|
||||
return cast(CursorResult, result).rowcount or 0
|
||||
|
||||
def get_pause_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowPause]:
|
||||
stmt = select(WorkflowPause).where(WorkflowPause.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def get_pause_reason_records_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
pause_ids: Sequence[str],
|
||||
) -> Sequence[WorkflowPauseReason]:
|
||||
if not pause_ids:
|
||||
return []
|
||||
|
||||
stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
|
|
@ -396,9 +436,8 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)))
|
||||
app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0
|
||||
|
||||
pause_ids = session.scalars(
|
||||
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
|
||||
).all()
|
||||
pause_stmt = select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
|
||||
pause_ids = session.scalars(pause_stmt).all()
|
||||
pause_reasons_deleted = 0
|
||||
pauses_deleted = 0
|
||||
|
||||
|
|
@ -407,7 +446,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids))
|
||||
)
|
||||
pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0
|
||||
pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
|
||||
pauses_result = session.execute(delete(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)))
|
||||
pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
|
||||
|
||||
trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
|
||||
|
|
@ -427,6 +466,124 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
"pause_reasons": pause_reasons_deleted,
|
||||
}
|
||||
|
||||
def get_app_logs_by_run_id(
|
||||
self,
|
||||
session: Session,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowAppLog]:
|
||||
stmt = select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == run_id)
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def create_archive_logs(
|
||||
self,
|
||||
session: Session,
|
||||
run: WorkflowRun,
|
||||
app_logs: Sequence[WorkflowAppLog],
|
||||
trigger_metadata: str | None,
|
||||
) -> int:
|
||||
if not app_logs:
|
||||
archive_log = WorkflowArchiveLog(
|
||||
log_id=None,
|
||||
log_created_at=None,
|
||||
log_created_from=None,
|
||||
tenant_id=run.tenant_id,
|
||||
app_id=run.app_id,
|
||||
workflow_id=run.workflow_id,
|
||||
workflow_run_id=run.id,
|
||||
created_by_role=run.created_by_role,
|
||||
created_by=run.created_by,
|
||||
run_version=run.version,
|
||||
run_status=run.status,
|
||||
run_triggered_from=run.triggered_from,
|
||||
run_error=run.error,
|
||||
run_elapsed_time=run.elapsed_time,
|
||||
run_total_tokens=run.total_tokens,
|
||||
run_total_steps=run.total_steps,
|
||||
run_created_at=run.created_at,
|
||||
run_finished_at=run.finished_at,
|
||||
run_exceptions_count=run.exceptions_count,
|
||||
trigger_metadata=trigger_metadata,
|
||||
)
|
||||
session.add(archive_log)
|
||||
return 1
|
||||
|
||||
archive_logs = [
|
||||
WorkflowArchiveLog(
|
||||
log_id=app_log.id,
|
||||
log_created_at=app_log.created_at,
|
||||
log_created_from=app_log.created_from,
|
||||
tenant_id=run.tenant_id,
|
||||
app_id=run.app_id,
|
||||
workflow_id=run.workflow_id,
|
||||
workflow_run_id=run.id,
|
||||
created_by_role=run.created_by_role,
|
||||
created_by=run.created_by,
|
||||
run_version=run.version,
|
||||
run_status=run.status,
|
||||
run_triggered_from=run.triggered_from,
|
||||
run_error=run.error,
|
||||
run_elapsed_time=run.elapsed_time,
|
||||
run_total_tokens=run.total_tokens,
|
||||
run_total_steps=run.total_steps,
|
||||
run_created_at=run.created_at,
|
||||
run_finished_at=run.finished_at,
|
||||
run_exceptions_count=run.exceptions_count,
|
||||
trigger_metadata=trigger_metadata,
|
||||
)
|
||||
for app_log in app_logs
|
||||
]
|
||||
session.add_all(archive_logs)
|
||||
return len(archive_logs)
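
A short usage sketch of create_archive_logs, mirroring how the archiver service later in this diff calls it: fetch the run's app logs first, then write the archive rows and commit in the same session. The run object and session wiring are assumed.

with session_maker() as session:
    app_logs = repo.get_app_logs_by_run_id(session, run.id)
    created = repo.create_archive_logs(session, run, app_logs, trigger_metadata=None)
    session.commit()  # one archive row per app log, or a single row when there are none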
|
||||
|
||||
def get_archived_runs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
tenant_ids: Sequence[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Retrieves WorkflowRun records by joining workflow_archive_logs.
|
||||
|
||||
Used to identify runs that are already archived and ready for deletion.
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowRun)
|
||||
.join(WorkflowArchiveLog, WorkflowArchiveLog.workflow_run_id == WorkflowRun.id)
|
||||
.where(
|
||||
WorkflowArchiveLog.run_created_at >= start_date,
|
||||
WorkflowArchiveLog.run_created_at < end_date,
|
||||
)
|
||||
.order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
if tenant_ids:
|
||||
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def get_archived_logs_by_time_range(
|
||||
self,
|
||||
session: Session,
|
||||
tenant_ids: Sequence[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int,
|
||||
) -> Sequence[WorkflowArchiveLog]:
|
||||
# Returns WorkflowArchiveLog rows directly; use this when workflow_runs may be deleted.
|
||||
stmt = (
|
||||
select(WorkflowArchiveLog)
|
||||
.where(
|
||||
WorkflowArchiveLog.run_created_at >= start_date,
|
||||
WorkflowArchiveLog.run_created_at < end_date,
|
||||
)
|
||||
.order_by(WorkflowArchiveLog.run_created_at.asc(), WorkflowArchiveLog.workflow_run_id.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
if tenant_ids:
|
||||
stmt = stmt.where(WorkflowArchiveLog.tenant_id.in_(tenant_ids))
|
||||
return list(session.scalars(stmt))
|
||||
|
||||
def count_runs_with_related(
|
||||
self,
|
||||
runs: Sequence[WorkflowRun],
|
||||
|
|
@ -459,7 +616,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
)
|
||||
|
||||
pause_ids = session.scalars(
|
||||
select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
|
||||
select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(run_ids))
|
||||
).all()
|
||||
pauses_count = len(pause_ids)
|
||||
pause_reasons_count = 0
|
||||
|
|
@ -511,9 +668,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
ValueError: If workflow_run_id is invalid or workflow run doesn't exist
|
||||
RuntimeError: If workflow is already paused or in invalid state
|
||||
"""
|
||||
previous_pause_model_query = select(WorkflowPauseModel).where(
|
||||
WorkflowPauseModel.workflow_run_id == workflow_run_id
|
||||
)
|
||||
previous_pause_model_query = select(WorkflowPause).where(WorkflowPause.workflow_run_id == workflow_run_id)
|
||||
with self._session_maker() as session, session.begin():
|
||||
# Get the workflow run
|
||||
workflow_run = session.get(WorkflowRun, workflow_run_id)
|
||||
|
|
@ -538,7 +693,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
# Upload the state file
|
||||
|
||||
# Create the pause record
|
||||
pause_model = WorkflowPauseModel()
|
||||
pause_model = WorkflowPause()
|
||||
pause_model.id = str(uuidv7())
|
||||
pause_model.workflow_id = workflow_run.workflow_id
|
||||
pause_model.workflow_run_id = workflow_run.id
|
||||
|
|
@ -710,13 +865,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
"""
|
||||
with self._session_maker() as session, session.begin():
|
||||
# Get the pause model by ID
|
||||
pause_model = session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model = session.get(WorkflowPause, pause_entity.id)
|
||||
if pause_model is None:
|
||||
raise _WorkflowRunError(f"WorkflowPause not found: {pause_entity.id}")
|
||||
self._delete_pause_model(session, pause_model)
|
||||
|
||||
@staticmethod
|
||||
def _delete_pause_model(session: Session, pause_model: WorkflowPauseModel):
|
||||
def _delete_pause_model(session: Session, pause_model: WorkflowPause):
|
||||
storage.delete(pause_model.state_object_key)
|
||||
|
||||
# Delete the pause record
|
||||
|
|
@ -751,15 +906,15 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
_limit: int = limit or 1000
|
||||
pruned_record_ids: list[str] = []
|
||||
cond = or_(
|
||||
WorkflowPauseModel.created_at < expiration,
|
||||
WorkflowPause.created_at < expiration,
|
||||
and_(
|
||||
WorkflowPauseModel.resumed_at.is_not(null()),
|
||||
WorkflowPauseModel.resumed_at < resumption_expiration,
|
||||
WorkflowPause.resumed_at.is_not(null()),
|
||||
WorkflowPause.resumed_at < resumption_expiration,
|
||||
),
|
||||
)
|
||||
# First, collect pause records to delete with their state files
|
||||
# Expired pauses (created before expiration time)
|
||||
stmt = select(WorkflowPauseModel).where(cond).limit(_limit)
|
||||
stmt = select(WorkflowPause).where(cond).limit(_limit)
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
# Old resumed pauses (resumed more than resumption_duration ago)
|
||||
|
|
@ -770,7 +925,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
# Delete state files from storage
|
||||
for pause in pauses_to_delete:
|
||||
with self._session_maker(expire_on_commit=False) as session, session.begin():
|
||||
# todo: this issues a separate query for each WorkflowPauseModel record.
|
||||
# todo: this issues a separate query for each WorkflowPause record.
|
||||
# consider batching this lookup.
|
||||
try:
|
||||
storage.delete(pause.state_object_key)
|
||||
|
|
@ -1022,7 +1177,7 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
|||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pause_model: WorkflowPauseModel,
|
||||
pause_model: WorkflowPause,
|
||||
reason_models: Sequence[WorkflowPauseReason],
|
||||
human_input_form: Sequence = (),
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@@ -46,6 +46,11 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):

        return self.session.scalar(query)

    def list_by_run_id(self, run_id: str) -> Sequence[WorkflowTriggerLog]:
        """List trigger logs for a workflow run."""
        query = select(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id == run_id)
        return list(self.session.scalars(query).all())

    def get_failed_for_retry(
        self, tenant_id: str, max_retry_count: int = 3, limit: int = 100
    ) -> Sequence[WorkflowTriggerLog]:
|
||||
|
|
|
|||
|
|
@@ -209,8 +209,12 @@ class AppAnnotationService:
        if not app:
            raise NotFound("App not found")

        question = args.get("question")
        if question is None:
            raise ValueError("'question' is required")

        annotation = MessageAnnotation(
            app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
            app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
        )
        db.session.add(annotation)
        db.session.commit()

@@ -219,7 +223,7 @@ class AppAnnotationService:
        if annotation_setting:
            add_annotation_to_index_task.delay(
                annotation.id,
                args["question"],
                question,
                current_tenant_id,
                app_id,
                annotation_setting.collection_binding_id,

@@ -244,8 +248,12 @@ class AppAnnotationService:
        if not annotation:
            raise NotFound("Annotation not found")

        question = args.get("question")
        if question is None:
            raise ValueError("'question' is required")

        annotation.content = args["answer"]
        annotation.question = args["question"]
        annotation.question = question

        db.session.commit()
        # if annotation reply is enabled , add annotation to index
|
||||
|
|
|
|||
|
|
@@ -521,12 +521,10 @@ class AppDslService:
            raise ValueError("Missing model_config for chat/agent-chat/completion app")
        # Initialize or update model config
        if not app.app_model_config:
            app_model_config = AppModelConfig().from_model_config_dict(model_config)
            app_model_config = AppModelConfig(
                app_id=app.id, created_by=account.id, updated_by=account.id
            ).from_model_config_dict(model_config)
            app_model_config.id = str(uuid4())
            app_model_config.app_id = app.id
            app_model_config.created_by = account.id
            app_model_config.updated_by = account.id

            app.app_model_config_id = app_model_config.id

            self._session.add(app_model_config)
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,8 @@
from __future__ import annotations

import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

from configs import dify_config
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator

@@ -18,6 +20,9 @@ from services.errors.app import QuotaExceededError, WorkflowIdFormatError, Workf
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService

if TYPE_CHECKING:
    from controllers.console.app.workflow import LoopNodeRunPayload


class AppGenerateService:
    @classmethod

@@ -165,7 +170,9 @@ class AppGenerateService:
        raise ValueError(f"Invalid app mode {app_model.mode}")

    @classmethod
    def generate_single_loop(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
    def generate_single_loop(
        cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True
    ):
        if app_model.mode == AppMode.ADVANCED_CHAT:
            workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
            return AdvancedChatAppGenerator.convert_to_event_stream(
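
The LoopNodeRunPayload annotation above only works at runtime because the module enables "from __future__ import annotations" and guards the controller import behind TYPE_CHECKING, which avoids a circular import between the services and controllers packages. A minimal standalone sketch of that pattern (the function name is illustrative, not part of this diff):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for static analysis; never executed at runtime.
    from controllers.console.app.workflow import LoopNodeRunPayload


def run_single_loop(args: LoopNodeRunPayload) -> None:
    # The annotation stays a string at runtime, so the import above is not needed here.
    ...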
|
||||
|
|
|
|||
|
|
@@ -150,10 +150,9 @@ class AppService:
        db.session.flush()

        if default_model_config:
            app_model_config = AppModelConfig(**default_model_config)
            app_model_config.app_id = app.id
            app_model_config.created_by = account.id
            app_model_config.updated_by = account.id
            app_model_config = AppModelConfig(
                **default_model_config, app_id=app.id, created_by=account.id, updated_by=account.id
            )
            db.session.add(app_model_config)
            db.session.flush()
|
||||
|
||||
|
|
|
|||
|
|
@@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field

from configs import dify_config
from enums.cloud_plan import CloudPlan
from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService

@@ -170,6 +171,9 @@ class SystemFeatureModel(BaseModel):
    plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
    enable_change_email: bool = True
    plugin_manager: PluginManagerModel = PluginManagerModel()
    trial_models: list[str] = []
    enable_trial_app: bool = False
    enable_explore_banner: bool = False


class FeatureService:
|
||||
|
|
@ -200,7 +204,7 @@ class FeatureService:
|
|||
return knowledge_rate_limit
|
||||
|
||||
@classmethod
|
||||
def get_system_features(cls) -> SystemFeatureModel:
|
||||
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
|
||||
system_features = SystemFeatureModel()
|
||||
|
||||
cls._fulfill_system_params_from_env(system_features)
|
||||
|
|
@ -210,7 +214,7 @@ class FeatureService:
|
|||
system_features.webapp_auth.enabled = True
|
||||
system_features.enable_change_email = False
|
||||
system_features.plugin_manager.enabled = True
|
||||
cls._fulfill_params_from_enterprise(system_features)
|
||||
cls._fulfill_params_from_enterprise(system_features, is_authenticated)
|
||||
|
||||
if dify_config.MARKETPLACE_ENABLED:
|
||||
system_features.enable_marketplace = True
|
||||
|
|
@ -225,6 +229,20 @@ class FeatureService:
|
|||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
system_features.trial_models = cls._fulfill_trial_models_from_env()
|
||||
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||
|
||||
@classmethod
|
||||
def _fulfill_trial_models_from_env(cls) -> list[str]:
|
||||
return [
|
||||
provider.value
|
||||
for provider in HostedTrialProvider
|
||||
if (
|
||||
getattr(dify_config, f"HOSTED_{provider.config_key}_PAID_ENABLED", False)
|
||||
and getattr(dify_config, f"HOSTED_{provider.config_key}_TRIAL_ENABLED", False)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_env(cls, features: FeatureModel):
|
||||
|
|
@@ -306,7 +324,7 @@ class FeatureService:
            features.next_credit_reset_date = billing_info["next_credit_reset_date"]

    @classmethod
    def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
    def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False):
        enterprise_info = EnterpriseService.get_info()

        if "SSOEnforcedForSignin" in enterprise_info:

@@ -343,19 +361,14 @@ class FeatureService:
            )
            features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")

        if "License" in enterprise_info:
            license_info = enterprise_info["License"]
        if is_authenticated and (license_info := enterprise_info.get("License")):
            features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
            features.license.expired_at = license_info.get("expiredAt", "")

            if "status" in license_info:
                features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))

            if "expiredAt" in license_info:
                features.license.expired_at = license_info["expiredAt"]

            if "workspaces" in license_info:
                features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
                features.license.workspaces.limit = license_info["workspaces"]["limit"]
                features.license.workspaces.size = license_info["workspaces"]["used"]
            if workspaces_info := license_info.get("workspaces"):
                features.license.workspaces.enabled = workspaces_info.get("enabled", False)
                features.license.workspaces.limit = workspaces_info.get("limit", 0)
                features.license.workspaces.size = workspaces_info.get("used", 0)

        if "PluginInstallationPermission" in enterprise_info:
            plugin_installation_info = enterprise_info["PluginInstallationPermission"]
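
The rewritten license block relies on an assignment expression plus .get() defaults, so the license fields are only applied for authenticated callers and partially populated payloads are tolerated. A small self-contained sketch of the same pattern with made-up data:

enterprise_info = {"License": {"status": "active", "workspaces": {"enabled": True, "limit": 5}}}

if license_info := enterprise_info.get("License"):
    status = license_info.get("status", "inactive")
    expired_at = license_info.get("expiredAt", "")
    if workspaces_info := license_info.get("workspaces"):
        enabled = workspaces_info.get("enabled", False)
        limit = workspaces_info.get("limit", 0)
        used = workspaces_info.get("used", 0)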
|
||||
|
|
|
|||
|
|
@@ -261,10 +261,9 @@ class MessageService:
        else:
            conversation_override_model_configs = json.loads(conversation.override_model_configs)
            app_model_config = AppModelConfig(
                id=conversation.app_model_config_id,
                app_id=app_model.id,
            )

            app_model_config.id = conversation.app_model_config_id
            app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
        if not app_model_config:
            raise ValueError("did not find app model config")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services.feature_service import FeatureService
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
|
||||
|
||||
|
|
@ -20,6 +23,15 @@ class RecommendedAppService:
|
|||
)
|
||||
)
|
||||
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
apps = result["recommended_apps"]
|
||||
for app in apps:
|
||||
app_id = app["app_id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
app["can_trial"] = True
|
||||
else:
|
||||
app["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
|
@ -32,4 +44,30 @@ class RecommendedAppService:
|
|||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
app_id = result["id"]
|
||||
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
|
||||
if trial_app_model:
|
||||
result["can_trial"] = True
|
||||
else:
|
||||
result["can_trial"] = False
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def add_trial_app_record(cls, app_id: str, account_id: str):
|
||||
"""
|
||||
Add trial app record.
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
account_trial_app_record = (
|
||||
db.session.query(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||
.first()
|
||||
)
|
||||
if account_trial_app_record:
|
||||
account_trial_app_record.count += 1
|
||||
db.session.commit()
|
||||
else:
|
||||
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||
db.session.commit()
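
The read-then-increment above can race under concurrent requests. A hedged alternative sketch using a PostgreSQL upsert, assuming (and this is an assumption, not something this diff adds) a unique constraint on (app_id, account_id) for the trial app record table:

from sqlalchemy.dialects.postgresql import insert as pg_insert

stmt = (
    pg_insert(AccountTrialAppRecord)
    .values(app_id=app_id, account_id=account_id, count=1)
    .on_conflict_do_update(
        index_elements=[AccountTrialAppRecord.app_id, AccountTrialAppRecord.account_id],
        set_={"count": AccountTrialAppRecord.count + 1},
    )
)
db.session.execute(stmt)  # single round trip, no read-modify-write race
db.session.commit()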
|
||||
|
|
|
|||
|
|
@@ -0,0 +1 @@
"""Workflow run retention services."""
|
||||
|
|
@ -0,0 +1,531 @@
|
|||
"""
|
||||
Archive Paid Plan Workflow Run Logs Service.
|
||||
|
||||
This service archives workflow run logs for paid plan users older than the configured
|
||||
retention period (default: 90 days) to S3-compatible storage.
|
||||
|
||||
Archived tables:
|
||||
- workflow_runs
|
||||
- workflow_app_logs
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import zipfile
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import WorkflowType
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.archive_storage import (
|
||||
ArchiveStorage,
|
||||
ArchiveStorageNotConfiguredError,
|
||||
get_archive_storage,
|
||||
)
|
||||
from models.workflow import WorkflowAppLog, WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.billing_service import BillingService
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableStats:
|
||||
"""Statistics for a single archived table."""
|
||||
|
||||
table_name: str
|
||||
row_count: int
|
||||
checksum: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchiveResult:
|
||||
"""Result of archiving a single workflow run."""
|
||||
|
||||
run_id: str
|
||||
tenant_id: str
|
||||
success: bool
|
||||
tables: list[TableStats] = field(default_factory=list)
|
||||
error: str | None = None
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArchiveSummary:
|
||||
"""Summary of the entire archive operation."""
|
||||
|
||||
total_runs_processed: int = 0
|
||||
runs_archived: int = 0
|
||||
runs_skipped: int = 0
|
||||
runs_failed: int = 0
|
||||
total_elapsed_time: float = 0.0
|
||||
|
||||
|
||||
class WorkflowRunArchiver:
|
||||
"""
|
||||
Archive workflow run logs for paid plan users.
|
||||
|
||||
Storage Layout:
|
||||
{tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/
|
||||
└── archive.v1.0.zip
|
||||
├── manifest.json
|
||||
├── workflow_runs.jsonl
|
||||
├── workflow_app_logs.jsonl
|
||||
├── workflow_node_executions.jsonl
|
||||
├── workflow_node_execution_offload.jsonl
|
||||
├── workflow_pauses.jsonl
|
||||
├── workflow_pause_reasons.jsonl
|
||||
└── workflow_trigger_logs.jsonl
|
||||
"""
|
||||
|
||||
ARCHIVED_TYPE = [
|
||||
WorkflowType.WORKFLOW,
|
||||
WorkflowType.RAG_PIPELINE,
|
||||
]
|
||||
ARCHIVED_TABLES = [
|
||||
"workflow_runs",
|
||||
"workflow_app_logs",
|
||||
"workflow_node_executions",
|
||||
"workflow_node_execution_offload",
|
||||
"workflow_pauses",
|
||||
"workflow_pause_reasons",
|
||||
"workflow_trigger_logs",
|
||||
]
|
||||
|
||||
start_from: datetime.datetime | None
|
||||
end_before: datetime.datetime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
days: int = 90,
|
||||
batch_size: int = 100,
|
||||
start_from: datetime.datetime | None = None,
|
||||
end_before: datetime.datetime | None = None,
|
||||
workers: int = 1,
|
||||
tenant_ids: Sequence[str] | None = None,
|
||||
limit: int | None = None,
|
||||
dry_run: bool = False,
|
||||
delete_after_archive: bool = False,
|
||||
workflow_run_repo: APIWorkflowRunRepository | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the archiver.
|
||||
|
||||
Args:
|
||||
days: Archive runs older than this many days
|
||||
batch_size: Number of runs to process per batch
|
||||
start_from: Optional start time (inclusive) for archiving
|
||||
end_before: Optional end time (exclusive) for archiving
|
||||
workers: Number of concurrent workflow runs to archive
|
||||
tenant_ids: Optional tenant IDs for grayscale rollout
|
||||
limit: Maximum number of runs to archive (None for unlimited)
|
||||
dry_run: If True, only preview without making changes
|
||||
delete_after_archive: If True, delete runs and related data after archiving
|
||||
"""
|
||||
self.days = days
|
||||
self.batch_size = batch_size
|
||||
if start_from or end_before:
|
||||
if start_from is None or end_before is None:
|
||||
raise ValueError("start_from and end_before must be provided together")
|
||||
if start_from >= end_before:
|
||||
raise ValueError("start_from must be earlier than end_before")
|
||||
self.start_from = start_from.replace(tzinfo=datetime.UTC)
|
||||
self.end_before = end_before.replace(tzinfo=datetime.UTC)
|
||||
else:
|
||||
self.start_from = None
|
||||
self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
|
||||
if workers < 1:
|
||||
raise ValueError("workers must be at least 1")
|
||||
self.workers = workers
|
||||
self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else []
|
||||
self.limit = limit
|
||||
self.dry_run = dry_run
|
||||
self.delete_after_archive = delete_after_archive
|
||||
self.workflow_run_repo = workflow_run_repo
|
||||
|
||||
def run(self) -> ArchiveSummary:
|
||||
"""
|
||||
Main archiving loop.
|
||||
|
||||
Returns:
|
||||
ArchiveSummary with statistics about the operation
|
||||
"""
|
||||
summary = ArchiveSummary()
|
||||
start_time = time.time()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
self._build_start_message(),
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize archive storage (will raise if not configured)
|
||||
try:
|
||||
if not self.dry_run:
|
||||
storage = get_archive_storage()
|
||||
else:
|
||||
storage = None
|
||||
except ArchiveStorageNotConfiguredError as e:
|
||||
click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
|
||||
return summary
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
repo = self._get_workflow_run_repo()
|
||||
|
||||
def _archive_with_session(run: WorkflowRun) -> ArchiveResult:
|
||||
with session_maker() as session:
|
||||
return self._archive_run(session, storage, run)
|
||||
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
archived_count = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.workers) as executor:
|
||||
while True:
|
||||
# Check limit
|
||||
if self.limit and archived_count >= self.limit:
|
||||
click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow"))
|
||||
break
|
||||
|
||||
# Fetch batch of runs
|
||||
runs = self._get_runs_batch(last_seen)
|
||||
|
||||
if not runs:
|
||||
break
|
||||
|
||||
run_ids = [run.id for run in runs]
|
||||
with session_maker() as session:
|
||||
archived_run_ids = repo.get_archived_run_ids(session, run_ids)
|
||||
|
||||
last_seen = (runs[-1].created_at, runs[-1].id)
|
||||
|
||||
# Filter to paid tenants only
|
||||
tenant_ids = {run.tenant_id for run in runs}
|
||||
paid_tenants = self._filter_paid_tenants(tenant_ids)
|
||||
|
||||
runs_to_process: list[WorkflowRun] = []
|
||||
for run in runs:
|
||||
summary.total_runs_processed += 1
|
||||
|
||||
# Skip non-paid tenants
|
||||
if run.tenant_id not in paid_tenants:
|
||||
summary.runs_skipped += 1
|
||||
continue
|
||||
|
||||
# Skip already archived runs
|
||||
if run.id in archived_run_ids:
|
||||
summary.runs_skipped += 1
|
||||
continue
|
||||
|
||||
# Check limit
|
||||
if self.limit and archived_count + len(runs_to_process) >= self.limit:
|
||||
break
|
||||
|
||||
runs_to_process.append(run)
|
||||
|
||||
if not runs_to_process:
|
||||
continue
|
||||
|
||||
results = list(executor.map(_archive_with_session, runs_to_process))
|
||||
|
||||
for run, result in zip(runs_to_process, results):
|
||||
if result.success:
|
||||
summary.runs_archived += 1
|
||||
archived_count += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} "
|
||||
f"run {run.id} (tenant={run.tenant_id}, "
|
||||
f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
summary.runs_failed += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to archive run {run.id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
summary.total_elapsed_time = time.time() - start_time
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: "
|
||||
f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
|
||||
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
|
||||
f"time={summary.total_elapsed_time:.2f}s",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
def _get_runs_batch(
|
||||
self,
|
||||
last_seen: tuple[datetime.datetime, str] | None,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""Fetch a batch of workflow runs to archive."""
|
||||
repo = self._get_workflow_run_repo()
|
||||
return repo.get_runs_batch_by_time_range(
|
||||
start_from=self.start_from,
|
||||
end_before=self.end_before,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
run_types=self.ARCHIVED_TYPE,
|
||||
tenant_ids=self.tenant_ids or None,
|
||||
)
|
||||
|
||||
def _build_start_message(self) -> str:
|
||||
range_desc = f"before {self.end_before.isoformat()}"
|
||||
if self.start_from:
|
||||
range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}"
|
||||
return (
|
||||
f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving "
|
||||
f"for runs {range_desc} "
|
||||
f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})"
|
||||
)
|
||||
|
||||
def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]:
|
||||
"""Filter tenant IDs to only include paid tenants."""
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
# If billing is not enabled, treat all tenants as paid
|
||||
return tenant_ids
|
||||
|
||||
if not tenant_ids:
|
||||
return set()
|
||||
|
||||
try:
|
||||
bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids))
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch billing plans for tenants")
|
||||
# On error, skip all tenants in this batch
|
||||
return set()
|
||||
|
||||
# Filter to paid tenants (any plan except SANDBOX)
|
||||
paid = set()
|
||||
for tid, info in bulk_info.items():
|
||||
if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM):
|
||||
paid.add(tid)
|
||||
|
||||
return paid
|
||||
|
||||
def _archive_run(
|
||||
self,
|
||||
session: Session,
|
||||
storage: ArchiveStorage | None,
|
||||
run: WorkflowRun,
|
||||
) -> ArchiveResult:
|
||||
"""Archive a single workflow run."""
|
||||
start_time = time.time()
|
||||
result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
|
||||
|
||||
try:
|
||||
# Extract data from all tables
|
||||
table_data, app_logs, trigger_metadata = self._extract_data(session, run)
|
||||
|
||||
if self.dry_run:
|
||||
# In dry run, just report what would be archived
|
||||
for table_name in self.ARCHIVED_TABLES:
|
||||
records = table_data.get(table_name, [])
|
||||
result.tables.append(
|
||||
TableStats(
|
||||
table_name=table_name,
|
||||
row_count=len(records),
|
||||
checksum="",
|
||||
size_bytes=0,
|
||||
)
|
||||
)
|
||||
result.success = True
|
||||
else:
|
||||
if storage is None:
|
||||
raise ArchiveStorageNotConfiguredError("Archive storage not configured")
|
||||
archive_key = self._get_archive_key(run)
|
||||
|
||||
# Serialize tables for the archive bundle
|
||||
table_stats: list[TableStats] = []
|
||||
table_payloads: dict[str, bytes] = {}
|
||||
for table_name in self.ARCHIVED_TABLES:
|
||||
records = table_data.get(table_name, [])
|
||||
data = ArchiveStorage.serialize_to_jsonl(records)
|
||||
table_payloads[table_name] = data
|
||||
checksum = ArchiveStorage.compute_checksum(data)
|
||||
|
||||
table_stats.append(
|
||||
TableStats(
|
||||
table_name=table_name,
|
||||
row_count=len(records),
|
||||
checksum=checksum,
|
||||
size_bytes=len(data),
|
||||
)
|
||||
)
|
||||
|
||||
# Generate and upload archive bundle
|
||||
manifest = self._generate_manifest(run, table_stats)
|
||||
manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8")
|
||||
archive_data = self._build_archive_bundle(manifest_data, table_payloads)
|
||||
storage.put_object(archive_key, archive_data)
|
||||
|
||||
repo = self._get_workflow_run_repo()
|
||||
archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
|
||||
session.commit()
|
||||
|
||||
deleted_counts = None
|
||||
if self.delete_after_archive:
|
||||
deleted_counts = repo.delete_runs_with_related(
|
||||
[run],
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s",
|
||||
run.id,
|
||||
{s.table_name: s.row_count for s in table_stats},
|
||||
archived_log_count,
|
||||
deleted_counts,
|
||||
)
|
||||
|
||||
result.tables = table_stats
|
||||
result.success = True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to archive workflow run %s", run.id)
|
||||
result.error = str(e)
|
||||
session.rollback()
|
||||
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
def _extract_data(
|
||||
self,
|
||||
session: Session,
|
||||
run: WorkflowRun,
|
||||
) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]:
|
||||
table_data: dict[str, list[dict[str, Any]]] = {}
|
||||
table_data["workflow_runs"] = [self._row_to_dict(run)]
|
||||
repo = self._get_workflow_run_repo()
|
||||
app_logs = repo.get_app_logs_by_run_id(session, run.id)
|
||||
table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs]
|
||||
node_exec_repo = self._get_workflow_node_execution_repo(session)
|
||||
node_exec_records = node_exec_repo.get_executions_by_workflow_run(
|
||||
tenant_id=run.tenant_id,
|
||||
app_id=run.app_id,
|
||||
workflow_run_id=run.id,
|
||||
)
|
||||
node_exec_ids = [record.id for record in node_exec_records]
|
||||
offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids)
|
||||
table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records]
|
||||
table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records]
|
||||
repo = self._get_workflow_run_repo()
|
||||
pause_records = repo.get_pause_records_by_run_id(session, run.id)
|
||||
pause_ids = [pause.id for pause in pause_records]
|
||||
pause_reason_records = repo.get_pause_reason_records_by_run_id(
|
||||
session,
|
||||
pause_ids,
|
||||
)
|
||||
table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records]
|
||||
table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records]
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_records = trigger_repo.list_by_run_id(run.id)
|
||||
table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records]
|
||||
trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None
|
||||
return table_data, app_logs, trigger_metadata
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: Any) -> dict[str, Any]:
|
||||
mapper = inspect(row).mapper
|
||||
return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns}
|
||||
|
||||
def _get_archive_key(self, run: WorkflowRun) -> str:
|
||||
"""Get the storage key for the archive bundle."""
|
||||
created_at = run.created_at
|
||||
prefix = (
|
||||
f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
|
||||
f"month={created_at.strftime('%m')}/workflow_run_id={run.id}"
|
||||
)
|
||||
return f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
run: WorkflowRun,
|
||||
table_stats: list[TableStats],
|
||||
) -> dict[str, Any]:
|
||||
"""Generate a manifest for the archived workflow run."""
|
||||
return {
|
||||
"schema_version": ARCHIVE_SCHEMA_VERSION,
|
||||
"workflow_run_id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"app_id": run.app_id,
|
||||
"workflow_id": run.workflow_id,
|
||||
"created_at": run.created_at.isoformat(),
|
||||
"archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
|
||||
"tables": {
|
||||
stat.table_name: {
|
||||
"row_count": stat.row_count,
|
||||
"checksum": stat.checksum,
|
||||
"size_bytes": stat.size_bytes,
|
||||
}
|
||||
for stat in table_stats
|
||||
},
|
||||
}
|
||||
|
||||
def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
|
||||
buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
|
||||
archive.writestr("manifest.json", manifest_data)
|
||||
for table_name in self.ARCHIVED_TABLES:
|
||||
data = table_payloads.get(table_name)
|
||||
if data is None:
|
||||
raise ValueError(f"Missing archive payload for {table_name}")
|
||||
archive.writestr(f"{table_name}.jsonl", data)
|
||||
return buffer.getvalue()
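
A hedged sketch of the reverse operation, reading a bundle back and checking each payload against the manifest checksums; it reuses ArchiveStorage.compute_checksum from this diff, while the verify_bundle function itself is illustrative.

import io
import json
import zipfile


def verify_bundle(archive_data: bytes) -> bool:
    with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as zf:
        manifest = json.loads(zf.read("manifest.json"))
        for table_name, stats in manifest["tables"].items():
            payload = zf.read(f"{table_name}.jsonl")
            if ArchiveStorage.compute_checksum(payload) != stats["checksum"]:
                return False
    return True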
|
||||
|
||||
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.delete_by_run_ids(run_ids)
|
||||
|
||||
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_ids = [run.id for run in runs]
|
||||
return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids)
|
||||
|
||||
def _get_workflow_node_execution_repo(
|
||||
self,
|
||||
session: Session,
|
||||
) -> DifyAPIWorkflowNodeExecutionRepository:
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False)
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
|
||||
|
||||
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||
if self.workflow_run_repo is not None:
|
||||
return self.workflow_run_repo
|
||||
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return self.workflow_run_repo
|
||||
|
|
@@ -0,0 +1,2 @@
ARCHIVE_SCHEMA_VERSION = "1.0"
ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip"
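
With the current schema version this resolves to the bundle name used in the archiver's storage layout:

assert ARCHIVE_BUNDLE_NAME == "archive.v1.0.zip"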
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
Delete Archived Workflow Run Service.
|
||||
|
||||
This service deletes archived workflow run data from the database while keeping
|
||||
archive logs intact.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteResult:
|
||||
run_id: str
|
||||
tenant_id: str
|
||||
success: bool
|
||||
deleted_counts: dict[str, int] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
class ArchivedWorkflowRunDeletion:
|
||||
def __init__(self, dry_run: bool = False):
|
||||
self.dry_run = dry_run
|
||||
self.workflow_run_repo: APIWorkflowRunRepository | None = None
|
||||
|
||||
def delete_by_run_id(self, run_id: str) -> DeleteResult:
|
||||
start_time = time.time()
|
||||
result = DeleteResult(run_id=run_id, tenant_id="", success=False)
|
||||
|
||||
repo = self._get_workflow_run_repo()
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
with session_maker() as session:
|
||||
run = session.get(WorkflowRun, run_id)
|
||||
if not run:
|
||||
result.error = f"Workflow run {run_id} not found"
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
result.tenant_id = run.tenant_id
|
||||
if not repo.get_archived_run_ids(session, [run.id]):
|
||||
result.error = f"Workflow run {run_id} is not archived"
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
result = self._delete_run(run)
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
def delete_batch(
|
||||
self,
|
||||
tenant_ids: list[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
) -> list[DeleteResult]:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
results: list[DeleteResult] = []
|
||||
|
||||
repo = self._get_workflow_run_repo()
|
||||
with session_maker() as session:
|
||||
runs = list(
|
||||
repo.get_archived_runs_by_time_range(
|
||||
session=session,
|
||||
tenant_ids=tenant_ids,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
for run in runs:
|
||||
results.append(self._delete_run(run))
|
||||
|
||||
return results
|
||||
|
||||
def _delete_run(self, run: WorkflowRun) -> DeleteResult:
|
||||
start_time = time.time()
|
||||
result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
|
||||
if self.dry_run:
|
||||
result.success = True
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
repo = self._get_workflow_run_repo()
|
||||
try:
|
||||
deleted_counts = repo.delete_runs_with_related(
|
||||
[run],
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
result.deleted_counts = deleted_counts
|
||||
result.success = True
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.delete_by_run_ids(run_ids)
|
||||
|
||||
@staticmethod
|
||||
def _delete_node_executions(
|
||||
session: Session,
|
||||
runs: Sequence[WorkflowRun],
|
||||
) -> tuple[int, int]:
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
run_ids = [run.id for run in runs]
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
|
||||
)
|
||||
return repo.delete_by_runs(session, run_ids)
|
||||
|
||||
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||
if self.workflow_run_repo is not None:
|
||||
return self.workflow_run_repo
|
||||
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
|
||||
sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
return self.workflow_run_repo
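
A minimal usage sketch of the deletion service, assuming it is driven from a maintenance command; the run ID below is a placeholder.

deletion = ArchivedWorkflowRunDeletion(dry_run=True)
result = deletion.delete_by_run_id("00000000-0000-0000-0000-000000000000")
if result.success:
    print(result.deleted_counts)  # empty in dry-run mode
else:
    print(result.error)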
|
||||
|
|
@ -0,0 +1,481 @@
|
|||
"""
|
||||
Restore Archived Workflow Run Service.
|
||||
|
||||
This service restores archived workflow run data from S3-compatible storage
|
||||
back to the database.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.archive_storage import (
|
||||
ArchiveStorage,
|
||||
ArchiveStorageNotConfiguredError,
|
||||
get_archive_storage,
|
||||
)
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.workflow import (
|
||||
WorkflowAppLog,
|
||||
WorkflowArchiveLog,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload,
|
||||
WorkflowPause,
|
||||
WorkflowPauseReason,
|
||||
WorkflowRun,
|
||||
)
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping of table names to SQLAlchemy models
|
||||
TABLE_MODELS = {
|
||||
"workflow_runs": WorkflowRun,
|
||||
"workflow_app_logs": WorkflowAppLog,
|
||||
"workflow_node_executions": WorkflowNodeExecutionModel,
|
||||
"workflow_node_execution_offload": WorkflowNodeExecutionOffload,
|
||||
"workflow_pauses": WorkflowPause,
|
||||
"workflow_pause_reasons": WorkflowPauseReason,
|
||||
"workflow_trigger_logs": WorkflowTriggerLog,
|
||||
}
|
||||
|
||||
SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]]
|
||||
|
||||
SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = {
|
||||
"1.0": {},
|
||||
}
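# --- Editor's sketch (not part of this diff): how a future schema version could
# be registered in SCHEMA_MAPPERS. The "1.1" version and the renamed column are
# hypothetical; a mapper takes the archived record and returns a dict that matches
# the current model columns.
def _map_workflow_runs_v1_1(record: dict[str, Any]) -> dict[str, Any]:
    mapped = dict(record)
    if "exec_metadata" in mapped:  # hypothetical old column name
        mapped["execution_metadata"] = mapped.pop("exec_metadata")
    return mapped

# SCHEMA_MAPPERS["1.1"] = {"workflow_runs": _map_workflow_runs_v1_1}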
|
||||
|
||||
|
||||
@dataclass
|
||||
class RestoreResult:
|
||||
"""Result of restoring a single workflow run."""
|
||||
|
||||
run_id: str
|
||||
tenant_id: str
|
||||
success: bool
|
||||
restored_counts: dict[str, int]
|
||||
error: str | None = None
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
class WorkflowRunRestore:
|
||||
"""
|
||||
Restore archived workflow run data from storage to database.
|
||||
|
||||
This service reads archived data from storage and restores it to the
|
||||
database tables. It handles idempotency by skipping records that already
|
||||
exist in the database.
|
||||
"""
|
||||
|
||||
def __init__(self, dry_run: bool = False, workers: int = 1):
|
||||
"""
|
||||
Initialize the restore service.
|
||||
|
||||
Args:
|
||||
dry_run: If True, only preview without making changes
|
||||
workers: Number of concurrent workflow runs to restore
|
||||
"""
|
||||
self.dry_run = dry_run
|
||||
if workers < 1:
|
||||
raise ValueError("workers must be at least 1")
|
||||
self.workers = workers
|
||||
self.workflow_run_repo: APIWorkflowRunRepository | None = None
|
||||
|
||||
def _restore_from_run(
|
||||
self,
|
||||
run: WorkflowRun | WorkflowArchiveLog,
|
||||
*,
|
||||
session_maker: sessionmaker,
|
||||
) -> RestoreResult:
|
||||
start_time = time.time()
|
||||
run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id
|
||||
created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at
|
||||
result = RestoreResult(
|
||||
run_id=run_id,
|
||||
tenant_id=run.tenant_id,
|
||||
success=False,
|
||||
restored_counts={},
|
||||
)
|
||||
|
||||
if not self.dry_run:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
storage = get_archive_storage()
|
||||
except ArchiveStorageNotConfiguredError as e:
|
||||
result.error = str(e)
|
||||
click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
prefix = (
|
||||
f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
|
||||
f"month={created_at.strftime('%m')}/workflow_run_id={run_id}"
|
||||
)
|
||||
archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
|
||||
try:
|
||||
archive_data = storage.get_object(archive_key)
|
||||
except FileNotFoundError:
|
||||
result.error = f"Archive bundle not found: {archive_key}"
|
||||
click.echo(click.style(result.error, fg="red"))
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
with session_maker() as session:
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive:
|
||||
try:
|
||||
manifest = self._load_manifest_from_zip(archive)
|
||||
except ValueError as e:
|
||||
result.error = f"Archive bundle invalid: {e}"
|
||||
click.echo(click.style(result.error, fg="red"))
|
||||
return result
|
||||
|
||||
tables = manifest.get("tables", {})
|
||||
schema_version = self._get_schema_version(manifest)
|
||||
for table_name, info in tables.items():
|
||||
row_count = info.get("row_count", 0)
|
||||
if row_count == 0:
|
||||
result.restored_counts[table_name] = 0
|
||||
continue
|
||||
|
||||
if self.dry_run:
|
||||
result.restored_counts[table_name] = row_count
|
||||
continue
|
||||
|
||||
member_path = f"{table_name}.jsonl"
|
||||
try:
|
||||
data = archive.read(member_path)
|
||||
except KeyError:
|
||||
click.echo(
|
||||
click.style(
|
||||
f" Warning: Table data not found in archive: {member_path}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
result.restored_counts[table_name] = 0
|
||||
continue
|
||||
|
||||
records = ArchiveStorage.deserialize_from_jsonl(data)
|
||||
restored = self._restore_table_records(
|
||||
session,
|
||||
table_name,
|
||||
records,
|
||||
schema_version=schema_version,
|
||||
)
|
||||
result.restored_counts[table_name] = restored
|
||||
if not self.dry_run:
|
||||
click.echo(
|
||||
click.style(
|
||||
f" Restored {restored}/{len(records)} records to {table_name}",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify row counts match manifest
|
||||
manifest_total = sum(info.get("row_count", 0) for info in tables.values())
|
||||
restored_total = sum(result.restored_counts.values())
|
||||
|
||||
if not self.dry_run:
|
||||
# Note: restored count might be less than manifest count if records already exist
|
||||
logger.info(
|
||||
"Restore verification: manifest_total=%d, restored_total=%d",
|
||||
manifest_total,
|
||||
restored_total,
|
||||
)
|
||||
|
||||
# Delete the archive log record after successful restore
|
||||
repo = self._get_workflow_run_repo()
|
||||
repo.delete_archive_log_by_run_id(session, run_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
result.success = True
|
||||
if not self.dry_run:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Completed restore for workflow run {run_id}: restored={result.restored_counts}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to restore workflow run %s", run_id)
|
||||
result.error = str(e)
|
||||
session.rollback()
|
||||
click.echo(click.style(f"Restore failed: {e}", fg="red"))
|
||||
|
||||
result.elapsed_time = time.time() - start_time
|
||||
return result
|
||||
|
||||
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||
if self.workflow_run_repo is not None:
|
||||
return self.workflow_run_repo
|
||||
|
||||
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
|
||||
sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
return self.workflow_run_repo
|
||||
|
||||
@staticmethod
|
||||
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
|
||||
try:
|
||||
data = archive.read("manifest.json")
|
||||
except KeyError as e:
|
||||
raise ValueError("manifest.json missing from archive bundle") from e
|
||||
return json.loads(data.decode("utf-8"))
|
||||
|
||||
def _restore_table_records(
|
||||
self,
|
||||
session: Session,
|
||||
table_name: str,
|
||||
records: list[dict[str, Any]],
|
||||
*,
|
||||
schema_version: str,
|
||||
) -> int:
|
||||
"""
|
||||
Restore records to a table.
|
||||
|
||||
Uses INSERT ... ON CONFLICT DO NOTHING for idempotency.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
table_name: Name of the table
|
||||
records: List of record dictionaries
|
||||
schema_version: Archived schema version from manifest
|
||||
|
||||
Returns:
|
||||
Number of records actually inserted
|
||||
"""
|
||||
if not records:
|
||||
return 0
|
||||
|
||||
model = TABLE_MODELS.get(table_name)
|
||||
if not model:
|
||||
logger.warning("Unknown table: %s", table_name)
|
||||
return 0
|
||||
|
||||
column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model)
|
||||
unknown_fields: set[str] = set()
|
||||
|
||||
# Apply schema mapping, filter to current columns, then convert datetimes
|
||||
converted_records = []
|
||||
for record in records:
|
||||
mapped = self._apply_schema_mapping(table_name, schema_version, record)
|
||||
unknown_fields.update(set(mapped.keys()) - column_names)
|
||||
filtered = {key: value for key, value in mapped.items() if key in column_names}
|
||||
for key in non_nullable_with_default:
|
||||
if key in filtered and filtered[key] is None:
|
||||
filtered.pop(key)
|
||||
missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None]
|
||||
if missing_required:
|
||||
missing_cols = ", ".join(sorted(missing_required))
|
||||
raise ValueError(
|
||||
f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}"
|
||||
)
|
||||
converted = self._convert_datetime_fields(filtered, model)
|
||||
converted_records.append(converted)
|
||||
if unknown_fields:
|
||||
logger.warning(
|
||||
"Dropped unknown columns for %s (schema_version=%s): %s",
|
||||
table_name,
|
||||
schema_version,
|
||||
", ".join(sorted(unknown_fields)),
|
||||
)
|
||||
|
||||
# Use INSERT ... ON CONFLICT DO NOTHING for idempotency
|
||||
stmt = pg_insert(model).values(converted_records)
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
|
||||
|
||||
result = session.execute(stmt)
|
||||
return cast(CursorResult, result).rowcount or 0
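# --- Editor's sketch (standalone, not part of this diff): the idempotent-insert
# pattern used above, in isolation. The table and DSN are placeholders; the point
# is the pg_insert / on_conflict_do_nothing / rowcount interplay — re-running the
# same insert adds nothing and rowcount reports 0.
from sqlalchemy import String, create_engine
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DemoRow(Base):
    __tablename__ = "demo_rows"
    id: Mapped[str] = mapped_column(String, primary_key=True)


engine = create_engine("postgresql+psycopg://localhost/demo")  # placeholder DSN
Base.metadata.create_all(engine)

with Session(engine) as session:
    stmt = pg_insert(DemoRow).values([{"id": "a"}, {"id": "b"}])
    stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
    first = session.execute(stmt).rowcount   # 2: both rows are new
    second = session.execute(stmt).rowcount  # 0: duplicates are skipped
    session.commit()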
|
||||
|
||||
def _convert_datetime_fields(
|
||||
self,
|
||||
record: dict[str, Any],
|
||||
model: type[DeclarativeBase] | Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert ISO datetime strings to datetime objects."""
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
result = dict(record)
|
||||
|
||||
for column in model.__table__.columns:
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result.get(column.key)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
result[column.key] = datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
|
||||
schema_version = manifest.get("schema_version")
|
||||
if not schema_version:
|
||||
logger.warning("Manifest missing schema_version; defaulting to 1.0")
|
||||
schema_version = "1.0"
|
||||
schema_version = str(schema_version)
|
||||
if schema_version not in SCHEMA_MAPPERS:
|
||||
raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.")
|
||||
return schema_version
|
||||
|
||||
def _apply_schema_mapping(
|
||||
self,
|
||||
table_name: str,
|
||||
schema_version: str,
|
||||
record: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
# Keep hook for forward/backward compatibility when schema evolves.
|
||||
mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name)
|
||||
if mapper is None:
|
||||
return dict(record)
|
||||
return mapper(record)
|
||||
|
||||
def _get_model_column_info(
|
||||
self,
|
||||
model: type[DeclarativeBase] | Any,
|
||||
) -> tuple[set[str], set[str], set[str]]:
|
||||
columns = list(model.__table__.columns)
|
||||
column_names = {column.key for column in columns}
|
||||
required_columns = {
|
||||
column.key
|
||||
for column in columns
|
||||
if not column.nullable
|
||||
and column.default is None
|
||||
and column.server_default is None
|
||||
and not column.autoincrement
|
||||
}
|
||||
non_nullable_with_default = {
|
||||
column.key
|
||||
for column in columns
|
||||
if not column.nullable
|
||||
and (column.default is not None or column.server_default is not None or column.autoincrement)
|
||||
}
|
||||
return column_names, required_columns, non_nullable_with_default
|
||||
|
||||
def restore_batch(
|
||||
self,
|
||||
tenant_ids: list[str] | None,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
) -> list[RestoreResult]:
|
||||
"""
|
||||
Restore multiple workflow runs by time range.
|
||||
|
||||
Args:
|
||||
tenant_ids: Optional tenant IDs
|
||||
start_date: Start date filter
|
||||
end_date: End date filter
|
||||
limit: Maximum number of runs to restore (default: 100)
|
||||
|
||||
Returns:
|
||||
List of RestoreResult objects
|
||||
"""
|
||||
results: list[RestoreResult] = []
|
||||
if tenant_ids is not None and not tenant_ids:
|
||||
return results
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
repo = self._get_workflow_run_repo()
|
||||
|
||||
with session_maker() as session:
|
||||
archive_logs = repo.get_archived_logs_by_time_range(
|
||||
session=session,
|
||||
tenant_ids=tenant_ids,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Found {len(archive_logs)} archived workflow runs to restore",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
|
||||
def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult:
|
||||
return self._restore_from_run(
|
||||
archive_log,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.workers) as executor:
|
||||
results = list(executor.map(_restore_with_session, archive_logs))
|
||||
|
||||
total_counts: dict[str, int] = {}
|
||||
for result in results:
|
||||
for table_name, count in result.restored_counts.items():
|
||||
total_counts[table_name] = total_counts.get(table_name, 0) + count
|
||||
success_count = sum(1 for result in results if result.success)
|
||||
|
||||
if self.dry_run:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def restore_by_run_id(
|
||||
self,
|
||||
run_id: str,
|
||||
) -> RestoreResult:
|
||||
"""
|
||||
Restore a single workflow run by run ID.
|
||||
"""
|
||||
repo = self._get_workflow_run_repo()
|
||||
archive_log = repo.get_archived_log_by_run_id(run_id)
|
||||
|
||||
if not archive_log:
|
||||
click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red"))
|
||||
return RestoreResult(
|
||||
run_id=run_id,
|
||||
tenant_id="",
|
||||
success=False,
|
||||
restored_counts={},
|
||||
error=f"Workflow run archive {run_id} not found",
|
||||
)
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
result = self._restore_from_run(archive_log, session_maker=session_maker)
|
||||
if self.dry_run and result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
return result
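# --- Editor's sketch (not part of this diff): how the restore service might be
# driven from a maintenance script or Flask shell. The date range and run id are
# placeholders, and an application context is assumed so db.engine is configured.
from datetime import UTC, datetime

restore = WorkflowRunRestore(dry_run=True, workers=4)

# Preview a batch restore for all tenants inside a time window.
results = restore.restore_batch(
    tenant_ids=None,
    start_date=datetime(2024, 1, 1, tzinfo=UTC),
    end_date=datetime(2024, 2, 1, tzinfo=UTC),
    limit=100,
)
failed = [r for r in results if not r.success]

# Or restore one archived run by id (placeholder UUID).
single = restore.restore_by_run_id("00000000-0000-0000-0000-000000000000")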
|
||||
|
|
@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
|
||||
from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||
from models.enums import AppTriggerType, CreatorUserRole
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
|
@@ -173,7 +173,80 @@ class WorkflowAppService:
|
|||
"data": items,
|
||||
}
|
||||
|
||||
def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
|
||||
def get_paginate_workflow_archive_logs(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
app_model: App,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
):
|
||||
"""
|
||||
Get paginated workflow archive logs using SQLAlchemy 2.0 style.
|
||||
"""
|
||||
stmt = select(WorkflowArchiveLog).where(
|
||||
WorkflowArchiveLog.tenant_id == app_model.tenant_id,
|
||||
WorkflowArchiveLog.app_id == app_model.id,
|
||||
WorkflowArchiveLog.log_id.isnot(None),
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc())
|
||||
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = session.scalar(count_stmt) or 0
|
||||
|
||||
offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
|
||||
|
||||
logs = list(session.scalars(offset_stmt).all())
|
||||
account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT}
|
||||
end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER}
|
||||
|
||||
accounts_by_id = {}
|
||||
if account_ids:
|
||||
accounts_by_id = {
|
||||
account.id: account
|
||||
for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all()
|
||||
}
|
||||
|
||||
end_users_by_id = {}
|
||||
if end_user_ids:
|
||||
end_users_by_id = {
|
||||
end_user.id: end_user
|
||||
for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all()
|
||||
}
|
||||
|
||||
items = []
|
||||
for log in logs:
|
||||
if log.created_by_role == CreatorUserRole.ACCOUNT:
|
||||
created_by_account = accounts_by_id.get(log.created_by)
|
||||
created_by_end_user = None
|
||||
elif log.created_by_role == CreatorUserRole.END_USER:
|
||||
created_by_account = None
|
||||
created_by_end_user = end_users_by_id.get(log.created_by)
|
||||
else:
|
||||
created_by_account = None
|
||||
created_by_end_user = None
|
||||
|
||||
items.append(
|
||||
{
|
||||
"id": log.id,
|
||||
"workflow_run": log.workflow_run_summary,
|
||||
"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata),
|
||||
"created_by_account": created_by_account,
|
||||
"created_by_end_user": created_by_end_user,
|
||||
"created_at": log.log_created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"has_more": total > page * limit,
|
||||
"data": items,
|
||||
}
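# --- Editor's note: the has_more flag above is plain page arithmetic; with
# illustrative numbers:
#   page, limit, total = 2, 20, 45
#   has_more = total > page * limit   # 45 > 40 -> True, a third page exists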
|
||||
|
||||
def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
|
||||
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
|
||||
if not metadata:
|
||||
return {}
|
||||
|
|
|
|||
|
|
@@ -11,8 +11,10 @@ from sqlalchemy.engine import CursorResult
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from models import (
|
||||
ApiToken,
|
||||
AppAnnotationHitHistory,
|
||||
|
|
@@ -43,6 +45,7 @@ from models.workflow import (
|
|||
ConversationVariable,
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowArchiveLog,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
|
@@ -67,6 +70,9 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
|
|||
_delete_app_workflow_runs(tenant_id, app_id)
|
||||
_delete_app_workflow_node_executions(tenant_id, app_id)
|
||||
_delete_app_workflow_app_logs(tenant_id, app_id)
|
||||
if dify_config.BILLING_ENABLED and dify_config.ARCHIVE_STORAGE_ENABLED:
|
||||
_delete_app_workflow_archive_logs(tenant_id, app_id)
|
||||
_delete_archived_workflow_run_files(tenant_id, app_id)
|
||||
_delete_app_conversations(tenant_id, app_id)
|
||||
_delete_app_messages(tenant_id, app_id)
|
||||
_delete_workflow_tool_providers(tenant_id, app_id)
|
||||
|
|
@@ -252,6 +258,45 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
|
|||
)
|
||||
|
||||
|
||||
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
|
||||
def del_workflow_archive_log(workflow_archive_log_id: str):
|
||||
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
_delete_records(
|
||||
"""select id from workflow_archive_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
|
||||
{"tenant_id": tenant_id, "app_id": app_id},
|
||||
del_workflow_archive_log,
|
||||
"workflow archive log",
|
||||
)
|
||||
|
||||
|
||||
def _delete_archived_workflow_run_files(tenant_id: str, app_id: str):
|
||||
prefix = f"{tenant_id}/app_id={app_id}/"
|
||||
try:
|
||||
archive_storage = get_archive_storage()
|
||||
except ArchiveStorageNotConfiguredError as e:
|
||||
logger.info("Archive storage not configured, skipping archive file cleanup: %s", e)
|
||||
return
|
||||
|
||||
try:
|
||||
keys = archive_storage.list_objects(prefix)
|
||||
except Exception:
|
||||
logger.exception("Failed to list archive files for app %s", app_id)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for key in keys:
|
||||
try:
|
||||
archive_storage.delete_object(key)
|
||||
deleted += 1
|
||||
except Exception:
|
||||
logger.exception("Failed to delete archive object %s", key)
|
||||
|
||||
logger.info("Deleted %s archive objects for app %s", deleted, app_id)
|
||||
|
||||
|
||||
def _delete_app_conversations(tenant_id: str, app_id: str):
|
||||
def del_conversation(session, conversation_id: str):
|
||||
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
|
||||
|
|
|
|||
|
|
@@ -5,13 +5,13 @@ import pytest
|
|||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
|
|
|||
|
|
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
|
|
|||
|
|
@@ -5,13 +5,13 @@ from collections.abc import Generator
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.llm_generator.output_parser.structured_output import _parse_structured_output
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
|
|
|
|||
|
|
@@ -4,11 +4,11 @@ import uuid
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
|
|||
|
|
@@ -4,10 +4,10 @@ import uuid
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
|
|||
|
|
@@ -3,12 +3,12 @@ import uuid
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
|
|
|||
|
|
@@ -172,7 +172,6 @@ class TestAgentService:
|
|||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
|
|
@@ -180,6 +179,7 @@ class TestAgentService:
|
|||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@@ -413,7 +413,6 @@ class TestAgentService:
|
|||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
|
|
@@ -421,6 +420,7 @@ class TestAgentService:
|
|||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
|
|
@@ -485,7 +485,6 @@ class TestAgentService:
|
|||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
id=fake.uuid4(),
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
|
|
@@ -493,6 +492,7 @@ class TestAgentService:
|
|||
model="gpt-3.5-turbo",
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@@ -220,6 +220,23 @@ class TestAnnotationService:
|
|||
# Note: In this test, no annotation setting exists, so task should not be called
|
||||
mock_external_service_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_insert_app_annotation_directly_requires_question(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Question must be provided when inserting annotations directly.
|
||||
"""
|
||||
fake = Faker()
|
||||
app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
annotation_args = {
|
||||
"question": None,
|
||||
"answer": fake.text(max_nb_chars=200),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
|
||||
|
||||
def test_insert_app_annotation_directly_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
|
|
|||
|
|
@@ -226,26 +226,27 @@ class TestAppDslService:
|
|||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create model config for the app
|
||||
model_config = AppModelConfig()
|
||||
model_config.id = fake.uuid4()
|
||||
model_config.app_id = app.id
|
||||
model_config.provider = "openai"
|
||||
model_config.model_id = "gpt-3.5-turbo"
|
||||
model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}
|
||||
),
|
||||
pre_prompt="You are a helpful assistant.",
|
||||
prompt_type="simple",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
model_config.pre_prompt = "You are a helpful assistant."
|
||||
model_config.prompt_type = "simple"
|
||||
model_config.created_by = account.id
|
||||
model_config.updated_by = account.id
|
||||
model_config.id = fake.uuid4()
|
||||
|
||||
# Set the app_model_config_id to link the config
|
||||
app.app_model_config_id = model_config.id
|
||||
|
|
|
|||
|
|
@@ -4,7 +4,13 @@ import pytest
|
|||
from faker import Faker
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel
|
||||
from services.feature_service import (
|
||||
FeatureModel,
|
||||
FeatureService,
|
||||
KnowledgeRateLimitModel,
|
||||
LicenseStatus,
|
||||
SystemFeatureModel,
|
||||
)
|
||||
|
||||
|
||||
class TestFeatureService:
|
||||
|
|
@@ -274,7 +280,7 @@ class TestFeatureService:
|
|||
mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = FeatureService.get_system_features()
|
||||
result = FeatureService.get_system_features(is_authenticated=True)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
|
|
@@ -324,6 +330,61 @@ class TestFeatureService:
|
|||
# Verify mock interactions
|
||||
mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
|
||||
|
||||
def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test system features retrieval for an unauthenticated user.
|
||||
|
||||
This test verifies that:
|
||||
- The response payload is minimized (e.g., verbose license details are excluded).
|
||||
- Essential UI configuration (Branding, SSO, Marketplace) remains available.
|
||||
- The response structure adheres to the public schema for unauthenticated clients.
|
||||
"""
|
||||
# Arrange: Setup test data with exact same config as success test
|
||||
with patch("services.feature_service.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_config.MARKETPLACE_ENABLED = True
|
||||
mock_config.ENABLE_EMAIL_CODE_LOGIN = True
|
||||
mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
|
||||
mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
|
||||
mock_config.ALLOW_REGISTER = False
|
||||
mock_config.ALLOW_CREATE_WORKSPACE = False
|
||||
mock_config.MAIL_TYPE = "smtp"
|
||||
mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100
|
||||
|
||||
# Act: Execute with is_authenticated=False
|
||||
result = FeatureService.get_system_features(is_authenticated=False)
|
||||
|
||||
# Assert: Basic structure
|
||||
assert result is not None
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
|
||||
# Ensure only essential UI flags are returned to unauthenticated clients
|
||||
# to keep the payload lightweight and adhere to architectural boundaries.
|
||||
assert result.license.status == LicenseStatus.NONE
|
||||
assert result.license.expired_at == ""
|
||||
assert result.license.workspaces.enabled is False
|
||||
assert result.license.workspaces.limit == 0
|
||||
assert result.license.workspaces.size == 0
|
||||
|
||||
# --- 2. Verify Public UI Configuration Availability ---
|
||||
# Ensure that data required for frontend rendering remains accessible.
|
||||
|
||||
# Branding should match the mock data
|
||||
assert result.branding.enabled is True
|
||||
assert result.branding.application_title == "Test Enterprise"
|
||||
assert result.branding.login_page_logo == "https://example.com/logo.png"
|
||||
|
||||
# SSO settings should be visible for login page rendering
|
||||
assert result.sso_enforced_for_signin is True
|
||||
assert result.sso_enforced_for_signin_protocol == "saml"
|
||||
|
||||
# General auth settings should be visible
|
||||
assert result.enable_email_code_login is True
|
||||
|
||||
# Marketplace should be visible
|
||||
assert result.enable_marketplace is True
|
||||
|
||||
def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test system features retrieval with basic configuration (no enterprise).
|
||||
|
|
@@ -1031,7 +1092,7 @@ class TestFeatureService:
|
|||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = FeatureService.get_system_features()
|
||||
result = FeatureService.get_system_features(is_authenticated=True)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
|
|
@@ -1400,7 +1461,7 @@ class TestFeatureService:
|
|||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = FeatureService.get_system_features()
|
||||
result = FeatureService.get_system_features(is_authenticated=True)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@@ -925,24 +925,24 @@ class TestWorkflowService:
|
|||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config.id = fake.uuid4()
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.tenant_id = app.tenant_id
|
||||
app_model_config.provider = "openai"
|
||||
app_model_config.model_id = "gpt-3.5-turbo"
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
app_model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
),
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
pre_prompt="You are a helpful assistant.",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
app_model_config.pre_prompt = "You are a helpful assistant."
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
@@ -987,24 +987,24 @@ class TestWorkflowService:
|
|||
# Create app model config (required for conversion)
|
||||
from models.model import AppModelConfig
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
app_model_config.id = fake.uuid4()
|
||||
app_model_config.app_id = app.id
|
||||
app_model_config.tenant_id = app.tenant_id
|
||||
app_model_config.provider = "openai"
|
||||
app_model_config.model_id = "gpt-3.5-turbo"
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
app_model_config.model = json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id,
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
# Set the model field directly - this is what model_dict property returns
|
||||
model=json.dumps(
|
||||
{
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"completion_params": {"max_tokens": 1000, "temperature": 0.7},
|
||||
}
|
||||
),
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
pre_prompt="Complete the following text:",
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
# Set pre_prompt for PromptTemplateConfigManager
|
||||
app_model_config.pre_prompt = "Complete the following text:"
|
||||
app_model_config.created_by = account.id
|
||||
app_model_config.updated_by = account.id
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,27 @@
|
|||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def test_console_ping_fastopenapi_returns_pong(app: Flask):
|
||||
ext_fastopenapi.init_app(app)
|
||||
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/ping")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "pong"}
|
||||
|
|
@@ -1,6 +1,8 @@
|
|||
"""Tests for execution context module."""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
|
@@ -149,6 +151,54 @@ class TestExecutionContext:
|
|||
|
||||
assert ctx.user == user
|
||||
|
||||
def test_thread_safe_context_manager(self):
|
||||
"""Test shared ExecutionContext works across threads without token mismatch."""
|
||||
test_var = contextvars.ContextVar("thread_safe_test_var")
|
||||
|
||||
class TrackingAppContext(AppContext):
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return default
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def enter(self):
|
||||
token = test_var.set(threading.get_ident())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
test_var.reset(token)
|
||||
|
||||
ctx = ExecutionContext(app_context=TrackingAppContext())
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for _ in range(20):
|
||||
with ctx:
|
||||
try:
|
||||
barrier.wait()
|
||||
barrier.wait()
|
||||
except threading.BrokenBarrierError:
|
||||
return
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
try:
|
||||
barrier.abort()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t1 = threading.Thread(target=worker)
|
||||
t2 = threading.Thread(target=worker)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join(timeout=5)
|
||||
t2.join(timeout=5)
|
||||
|
||||
assert not errors
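# --- Editor's note: the failure mode this test guards against, shown with the
# stdlib only. A contextvars Token can only be reset in the context (thread)
# where it was created, so a shared context object must pair set()/reset()
# inside the same thread's `with` block.
import contextvars
import threading

demo_var = contextvars.ContextVar("demo")
demo_token = demo_var.set("set on the main thread")


def reset_elsewhere() -> None:
    try:
        demo_var.reset(demo_token)
    except ValueError as exc:
        print(exc)  # "... was created in a different Context"


threading.Thread(target=reset_elsewhere).start()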
|
||||
|
||||
|
||||
class TestIExecutionContextProtocol:
|
||||
"""Test IExecutionContext protocol."""
|
||||
|
|
|
|||
|
|
@@ -7,9 +7,9 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request
|
|||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
|
||||
from .test_mock_nodes import (
|
||||
MockAgentNode,
|
||||
|
|
|
|||
|
|
@@ -13,6 +13,7 @@ from unittest.mock import patch
|
|||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
|
|
@@ -26,7 +27,6 @@ from core.workflow.graph_events import (
|
|||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
|
|
|||
|
|
@@ -19,6 +19,7 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.tools.utils.yaml_utils import _load_yaml_file
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
|
|
@@ -38,7 +39,6 @@ from core.workflow.graph_events import (
|
|||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
|
|
|||
|
|
@@ -3,11 +3,11 @@ import uuid
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
|
|
|
|||
|
|
@@ -475,3 +475,130 @@ def test_valid_api_key_works():
|
|||
headers = executor._assembling_headers()
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer valid-api-key-123"
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_unquoted_uuid_variable():
|
||||
"""Test that unquoted UUID variables are correctly handled in JSON body.
|
||||
|
||||
This test verifies the fix for issue #31436 where json_repair would truncate
|
||||
certain UUID patterns (like 57eeeeb1-...) when they appeared as unquoted values.
|
||||
"""
|
||||
# UUID that triggers the json_repair truncation bug
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Unquoted UUID Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
# UUID variable without quotes - this is the problematic case
|
||||
value='{"rowId": {{#pre_node_id.uuid#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# The UUID should be preserved in full, not truncated
|
||||
assert executor.json == {"rowId": test_uuid}
|
||||
assert len(executor.json["rowId"]) == len(test_uuid)
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_unquoted_uuid_with_newlines():
|
||||
"""Test that unquoted UUID variables with newlines in JSON are handled correctly.
|
||||
|
||||
This is a specific case from issue #31436 where the JSON body contains newlines.
|
||||
"""
|
||||
test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2"
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "uuid"], test_uuid)
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Unquoted UUID and Newlines",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
# JSON with newlines and unquoted UUID variable
|
||||
value='{\n"rowId": {{#pre_node_id.uuid#}}\n}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# The UUID should be preserved in full
|
||||
assert executor.json == {"rowId": test_uuid}
|
||||
|
||||
|
||||
def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
"""Test that numbers are preserved and string values are properly quoted."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["node", "count"], 42)
|
||||
variable_pool.add(["node", "id"], "abc-123")
|
||||
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with mixed types",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"count": {{#node.count#}}, "id": {{#node.id#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.json["count"] == 42
|
||||
assert executor.json["id"] == "abc-123"
|
||||
|
|
|
|||
|
|
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock, Mock
|
|||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
|
|
@@ -12,7 +13,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
|
|
|
|||
|
|
@@ -3,11 +3,11 @@ import uuid
|
|||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||
|
|
|
|||
|
|
@@ -3,10 +3,10 @@ import uuid
|
|||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.variables import ArrayStringVariable
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
|
|
|
|||