Mirror of https://github.com/langgenius/dify.git

commit de72bdef71
Merge branch 'main' into fix/main-enterprise-api-error-handling
@@ -7,7 +7,7 @@ cd web && pnpm install
 pipx install uv

 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
.vscode/launch.json.template | 2 (vendored)
@@ -37,7 +37,7 @@
             "-c",
             "1",
             "-Q",
-            "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
+            "dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
             "--loglevel",
             "INFO"
         ],
Makefile | 5
@@ -68,8 +68,9 @@ lint:
 	@echo "✅ Linting complete"

 type-check:
-	@echo "📝 Running type checks (basedpyright + mypy)..."
+	@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
 	@./dev/basedpyright-check $(PATH_TO_CHECK)
+	@./dev/pyrefly-check-local
 	@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
 	@echo "✅ Type checks complete"

@@ -131,7 +132,7 @@ help:
 	@echo " make format - Format code with ruff"
 	@echo " make check - Check code with ruff"
 	@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
-	@echo " make type-check - Run type checks (basedpyright, mypy)"
+	@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
 	@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
 	@echo ""
 	@echo "Docker Build Targets:"
@@ -28,17 +28,8 @@ ignore_imports =
    dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events
    dify_graph.nodes.loop.loop_node -> dify_graph.graph_events

    dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
    dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
    dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
    dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota

    dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine
    dify_graph.nodes.iteration.iteration_node -> dify_graph.graph
    dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine.command_channels
    dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine
    dify_graph.nodes.loop.loop_node -> dify_graph.graph
    dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine.command_channels
    # TODO(QuantumGhost): fix the import violation later
    dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities

@@ -101,12 +92,9 @@ forbidden_modules =
    core.trigger
    core.variables
ignore_imports =
    dify_graph.nodes.loop.loop_node -> core.workflow.node_factory
    dify_graph.nodes.agent.agent_node -> core.model_manager
    dify_graph.nodes.agent.agent_node -> core.provider_manager
    dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
    dify_graph.nodes.iteration.iteration_node -> core.workflow.node_factory
    dify_graph.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
    dify_graph.nodes.llm.llm_utils -> core.model_manager
    dify_graph.nodes.llm.protocols -> core.model_manager
    dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model

@@ -151,7 +139,6 @@ ignore_imports =
    dify_graph.nodes.llm.node -> extensions.ext_database
    dify_graph.nodes.tool.tool_node -> extensions.ext_database
    dify_graph.nodes.agent.agent_node -> models
    dify_graph.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
    dify_graph.nodes.llm.node -> models.model
    dify_graph.nodes.agent.agent_node -> services
    dify_graph.nodes.tool.tool_node -> services
@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co

- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.

```python
from datetime import datetime
from typing import NotRequired, TypedDict


class UserProfile(TypedDict):
    user_id: str
    email: str
    created_at: datetime
    nickname: NotRequired[str]
```

- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:

```python
# Illustrative sketch; members are declared before __init__.
class UserService:
    user_repo: "UserRepository"  # hypothetical collaborator type
    cache_ttl: int

    def __init__(self, user_repo: "UserRepository", cache_ttl: int = 300):
        self.user_repo = user_repo
        self.cache_ttl = cache_ttl
```
@@ -2668,3 +2668,77 @@ def clean_expired_messages(
        raise

    click.echo(click.style("messages cleanup completed.", fg="green"))


@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
    "--end-before",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    required=True,
    help="Upper bound (exclusive) for created_at.",
)
@click.option(
    "--filename",
    required=True,
    help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
    app_id: str,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime,
    filename: str,
    use_cloud_storage: bool,
    batch_size: int,
    dry_run: bool,
):
    if start_from and start_from >= end_before:
        raise click.UsageError("--start-from must be before --end-before.")

    from services.retention.conversation.message_export_service import AppMessageExportService

    try:
        validated_filename = AppMessageExportService.validate_export_filename(filename)
    except ValueError as e:
        raise click.BadParameter(str(e), param_hint="--filename") from e

    click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
    start_at = time.perf_counter()

    try:
        service = AppMessageExportService(
            app_id=app_id,
            end_before=end_before,
            filename=validated_filename,
            start_from=start_from,
            batch_size=batch_size,
            use_cloud_storage=use_cloud_storage,
            dry_run=dry_run,
        )
        stats = service.run()

        elapsed = time.perf_counter() - start_at
        click.echo(
            click.style(
                f"export_app_messages: completed in {elapsed:.2f}s\n"
                f" - Batches: {stats.batches}\n"
                f" - Total messages: {stats.total_messages}\n"
                f" - Messages with feedback: {stats.messages_with_feedback}\n"
                f" - Total feedbacks: {stats.total_feedbacks}",
                fg="green",
            )
        )
    except Exception as e:
        elapsed = time.perf_counter() - start_at
        logger.exception("export_app_messages failed")
        click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
        raise
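For a quick smoke test of the new command's option parsing, Click's test runner can drive a dry run; a minimal sketch, assuming the command is importable from the api `commands` module (the app ID and filename below are placeholders):

```python
from click.testing import CliRunner

from commands import export_app_messages  # assumed module path

runner = CliRunner()
result = runner.invoke(
    export_app_messages,
    [
        "--app-id", "00000000-0000-0000-0000-000000000000",  # placeholder
        "--end-before", "2026-01-01",
        "--filename", "exports/messages_batch_1",
        "--dry-run",  # scan only; no file is written
    ],
)
print(result.exit_code, result.output)
```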
@@ -1,3 +1,5 @@
+from typing import Any, cast
+
 from controllers.common import fields
 from controllers.console import console_ns
 from controllers.console.app.error import AppUnavailableError
@@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource):
            if workflow is None:
                raise AppUnavailableError()

-           features_dict = workflow.features_dict
+           features_dict: dict[str, Any] = workflow.features_dict
            user_input_form = workflow.user_input_form(to_old_structure=True)
        else:
            app_model_config = app_model.app_model_config
            if app_model_config is None:
                raise AppUnavailableError()

-           features_dict = app_model_config.to_dict()
+           features_dict = cast(dict[str, Any], app_model_config.to_dict())

        user_input_form = features_dict.get("user_input_form", [])
||||
@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@@ -1,3 +1,5 @@
+from typing import Any, cast
+
 from flask_restx import Resource

 from controllers.common.fields import Parameters
@@ -33,14 +35,14 @@ class AppParameterApi(Resource):
            if workflow is None:
                raise AppUnavailableError()

-           features_dict = workflow.features_dict
+           features_dict: dict[str, Any] = workflow.features_dict
            user_input_form = workflow.user_input_form(to_old_structure=True)
        else:
            app_model_config = app_model.app_model_config
            if app_model_config is None:
                raise AppUnavailableError()

-           features_dict = app_model_config.to_dict()
+           features_dict = cast(dict[str, Any], app_model_config.to_dict())

        user_input_form = features_dict.get("user_input_form", [])
@@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from fields.conversation_fields import (
-    ConversationDelete,
     ConversationInfiniteScrollPagination,
     SimpleConversation,
 )
@@ -163,7 +162,7 @@ class ConversationDetailApi(Resource):
            ConversationService.delete(app_model, conversation_id, end_user)
        except services.errors.conversation.ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
-       return ConversationDelete(result="success").model_dump(mode="json"), 204
+       return "", 204


@service_api_ns.route("/conversations/<uuid:c_id>/name")
@@ -132,6 +132,8 @@ class WorkflowRunDetailApi(Resource):
            app_id=app_model.id,
            run_id=workflow_run_id,
        )
+       if not workflow_run:
+           raise NotFound("Workflow run not found.")
        return workflow_run
@@ -1,4 +1,5 @@
 import logging
+from typing import Any, cast

 from flask import request
 from flask_restx import Resource
@@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource):
            if workflow is None:
                raise AppUnavailableError()

-           features_dict = workflow.features_dict
+           features_dict: dict[str, Any] = workflow.features_dict
            user_input_form = workflow.user_input_form(to_old_structure=True)
        else:
            app_model_config = app_model.app_model_config
            if app_model_config is None:
                raise AppUnavailableError()

-           features_dict = app_model_config.to_dict()
+           features_dict = cast(dict[str, Any], app_model_config.to_dict())

        user_input_form = features_dict.get("user_input_form", [])
@@ -1,10 +1,13 @@
+from collections.abc import Mapping
+from typing import Any
+
 from core.app.app_config.entities import SensitiveWordAvoidanceEntity
 from core.moderation.factory import ModerationFactory


 class SensitiveWordAvoidanceConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
+    def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
         sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
         if not sensitive_word_avoidance_dict:
             return None
@@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager:
         if sensitive_word_avoidance_dict.get("enabled"):
             return SensitiveWordAvoidanceEntity(
                 type=sensitive_word_avoidance_dict.get("type"),
-                config=sensitive_word_avoidance_dict.get("config"),
+                config=sensitive_word_avoidance_dict.get("config", {}),
             )
         else:
             return None
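With the signature widened to `Mapping[str, Any]`, read-only mappings are now valid inputs as well as plain dicts; a small sketch (the payload below is hypothetical):

```python
from types import MappingProxyType

from core.app.app_config.common.sensitive_word_avoidance.manager import (
    SensitiveWordAvoidanceConfigManager,
)

# A read-only mapping now satisfies the signature; the payload is illustrative.
config = MappingProxyType(
    {"sensitive_word_avoidance": {"enabled": True, "type": "keywords", "config": {"keywords": ["banned"]}}}
)
entity = SensitiveWordAvoidanceConfigManager.convert(config=config)
print(entity)
```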
@@ -1,10 +1,13 @@
+from typing import Any, cast
+
 from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
 from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
+from models.model import AppModelConfigDict


 class AgentConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> AgentEntity | None:
+    def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
         """
         Convert model config to model config

@@ -28,17 +31,17 @@ class AgentConfigManager:

        agent_tools = []
        for tool in agent_dict.get("tools", []):
-           keys = tool.keys()
-           if len(keys) >= 4:
-               if "enabled" not in tool or not tool["enabled"]:
+           tool_dict = cast(dict[str, Any], tool)
+           if len(tool_dict) >= 4:
+               if "enabled" not in tool_dict or not tool_dict["enabled"]:
                    continue

                agent_tool_properties = {
-                   "provider_type": tool["provider_type"],
-                   "provider_id": tool["provider_id"],
-                   "tool_name": tool["tool_name"],
-                   "tool_parameters": tool.get("tool_parameters", {}),
-                   "credential_id": tool.get("credential_id", None),
+                   "provider_type": tool_dict["provider_type"],
+                   "provider_id": tool_dict["provider_id"],
+                   "tool_name": tool_dict["tool_name"],
+                   "tool_parameters": tool_dict.get("tool_parameters", {}),
+                   "credential_id": tool_dict.get("credential_id", None),
                }

                agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
@@ -47,7 +50,8 @@ class AgentConfigManager:
            "react_router",
            "router",
        }:
-           agent_prompt = agent_dict.get("prompt", None) or {}
+           agent_prompt_raw = agent_dict.get("prompt", None)
+           agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
            # check model mode
            model_mode = config.get("model", {}).get("mode", "completion")
            if model_mode == "completion":
@@ -75,7 +79,7 @@ class AgentConfigManager:
                strategy=strategy,
                prompt=agent_prompt_entity,
                tools=agent_tools,
-               max_iteration=agent_dict.get("max_iteration", 10),
+               max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
            )

        return None
@@ -1,5 +1,5 @@
 import uuid
-from typing import Literal, cast
+from typing import Any, Literal, cast

 from core.app.app_config.entities import (
     DatasetEntity,
@@ -8,13 +8,13 @@ from core.app.app_config.entities import (
     ModelConfig,
 )
 from core.entities.agent_entities import PlanningStrategy
-from models.model import AppMode
+from models.model import AppMode, AppModelConfigDict
 from services.dataset_service import DatasetService


 class DatasetConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> DatasetEntity | None:
+    def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
         """
         Convert model config to model config

@@ -25,11 +25,15 @@ class DatasetConfigManager:
        datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})

        for dataset in datasets.get("datasets", []):
+           if not isinstance(dataset, dict):
+               continue
            keys = list(dataset.keys())
            if len(keys) == 0 or keys[0] != "dataset":
                continue

            dataset = dataset["dataset"]
+           if not isinstance(dataset, dict):
+               continue

            if "enabled" not in dataset or not dataset["enabled"]:
                continue
@@ -47,15 +51,14 @@ class DatasetConfigManager:
        agent_dict = config.get("agent_mode", {})

        for tool in agent_dict.get("tools", []):
-           keys = tool.keys()
-           if len(keys) == 1:
+           if len(tool) == 1:
                # old standard
                key = list(tool.keys())[0]

                if key != "dataset":
                    continue

-               tool_item = tool[key]
+               tool_item = cast(dict[str, Any], tool)[key]

                if "enabled" not in tool_item or not tool_item["enabled"]:
                    continue
@@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity
 from core.provider_manager import ProviderManager
 from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from models.model import AppModelConfigDict
 from models.provider_ids import ModelProviderID


 class ModelConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> ModelConfigEntity:
+    def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
         """
         Convert model config to model config

@@ -22,7 +23,7 @@ class ModelConfigManager:
        if not model_config:
            raise ValueError("model is required")

-       completion_params = model_config.get("completion_params")
+       completion_params = model_config.get("completion_params") or {}
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
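The `or {}` matters because `dict.get` returns `None` both when the key is missing and when it is present with an explicit `None` value, and `"stop" in None` raises `TypeError`; a standalone illustration:

```python
model_config = {"name": "gpt-4o", "completion_params": None}  # illustrative payload

params = model_config.get("completion_params")        # -> None
# "stop" in params                                    # would raise TypeError

params = model_config.get("completion_params", {})    # still None: the key exists
params = model_config.get("completion_params") or {}  # -> {} in both cases
print("stop" in params)                               # False, no crash
```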
@@ -1,3 +1,5 @@
+from typing import Any
+
 from core.app.app_config.entities import (
     AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
@@ -6,12 +8,12 @@ from core.app.app_config.entities import (
 )
 from core.prompt.simple_prompt_transform import ModelMode
 from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
-from models.model import AppMode
+from models.model import AppMode, AppModelConfigDict


 class PromptTemplateConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> PromptTemplateEntity:
+    def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
         if not config.get("prompt_type"):
             raise ValueError("prompt_type is required")

@@ -40,14 +42,15 @@ class PromptTemplateConfigManager:
        advanced_completion_prompt_template = None
        completion_prompt_config = config.get("completion_prompt_config", {})
        if completion_prompt_config:
-           completion_prompt_template_params = {
+           completion_prompt_template_params: dict[str, Any] = {
                "prompt": completion_prompt_config["prompt"]["text"],
            }

-           if "conversation_histories_role" in completion_prompt_config:
+           conv_role = completion_prompt_config.get("conversation_histories_role")
+           if conv_role:
                completion_prompt_template_params["role_prefix"] = {
-                   "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
-                   "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
+                   "user": conv_role["user_prefix"],
+                   "assistant": conv_role["assistant_prefix"],
                }

            advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -1,8 +1,10 @@
 import re
+from typing import cast

 from core.app.app_config.entities import ExternalDataVariableEntity
 from core.external_data_tool.factory import ExternalDataToolFactory
 from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
+from models.model import AppModelConfigDict

 _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
     [
@@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(

 class BasicVariablesConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
+    def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
         """
         Convert model config to model config

@@ -51,7 +53,9 @@ class BasicVariablesConfigManager:

                external_data_variables.append(
                    ExternalDataVariableEntity(
-                       variable=variable["variable"], type=variable["type"], config=variable["config"]
+                       variable=variable["variable"],
+                       type=variable.get("type", ""),
+                       config=variable.get("config", {}),
                    )
                )
            elif variable_type in {
@@ -64,10 +68,10 @@ class BasicVariablesConfigManager:
                variable = variables[variable_type]
                variable_entities.append(
                    VariableEntity(
-                       type=variable_type,
-                       variable=variable.get("variable"),
+                       type=cast(VariableEntityType, variable_type),
+                       variable=variable["variable"],
                        description=variable.get("description") or "",
-                       label=variable.get("label"),
+                       label=variable["label"],
                        required=variable.get("required", False),
                        max_length=variable.get("max_length"),
                        options=variable.get("options") or [],
@@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):

    app_model_config_from: EasyUIBasedAppModelConfigFrom
    app_model_config_id: str
-   app_model_config_dict: dict
+   app_model_config_dict: dict[str, Any]
    model: ModelConfigEntity
    prompt_template: PromptTemplateEntity
    dataset: DatasetEntity | None = None
@@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
            graph_runtime_state=validated_state,
        )

+       yield from self._handle_advanced_chat_message_end_event(
+           QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
+       )
        yield workflow_finish_resp
-       self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)

    def _handle_workflow_partial_success_event(
        self,
@@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
            exceptions_count=event.exceptions_count,
        )

+       yield from self._handle_advanced_chat_message_end_event(
+           QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
+       )
        yield workflow_finish_resp

    def _handle_workflow_paused_event(
@@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                    yield from self._handle_workflow_paused_event(event)
                    break

+               case QueueWorkflowSucceededEvent():
+                   yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
+                   break
+
+               case QueueWorkflowPartialSuccessEvent():
+                   yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
+                   break
+
                case QueueStopEvent():
                    yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
                    break
@@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
 from core.entities.agent_entities import PlanningStrategy
-from models.model import App, AppMode, AppModelConfig, Conversation
+from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation

 OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]

@@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        app_model: App,
        app_model_config: AppModelConfig,
        conversation: Conversation | None = None,
-       override_config_dict: dict | None = None,
+       override_config_dict: AppModelConfigDict | None = None,
    ) -> AgentChatAppConfig:
        """
        Convert app model config to agent chat app config
@@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
            app_model_config_dict = app_model_config.to_dict()
            config_dict = app_model_config_dict.copy()
        else:
-           config_dict = override_config_dict or {}
+           if not override_config_dict:
+               raise Exception("override_config_dict is required when config_from is ARGS")
+           config_dict = override_config_dict

        app_mode = AppMode.value_of(app_model.mode)
        app_config = AgentChatAppConfig(
@@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
            app_mode=app_mode,
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
-           app_model_config_dict=config_dict,
+           app_model_config_dict=cast(dict[str, Any], config_dict),
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        return app_config

    @classmethod
-   def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
+   def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
        """
        Validate for agent chat app model config

@@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

-       return filtered_config
+       return cast(AppModelConfigDict, filtered_config)

    @classmethod
    def validate_agent_mode_and_set_defaults(
@@ -1,3 +1,5 @@
+from typing import Any, cast
+
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
     SuggestedQuestionsAfterAnswerConfigManager,
 )
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig, Conversation
+from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation


 class ChatAppConfig(EasyUIBasedAppConfig):
@@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
        app_model: App,
        app_model_config: AppModelConfig,
        conversation: Conversation | None = None,
-       override_config_dict: dict | None = None,
+       override_config_dict: AppModelConfigDict | None = None,
    ) -> ChatAppConfig:
        """
        Convert app model config to chat app config
@@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
            app_mode=app_mode,
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
-           app_model_config_dict=config_dict,
+           app_model_config_dict=cast(dict[str, Any], config_dict),
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
        return app_config

    @classmethod
-   def config_validate(cls, tenant_id: str, config: dict):
+   def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
        """
        Validate for chat app model config

@@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

-       return filtered_config
+       return cast(AppModelConfigDict, filtered_config)
@@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner):
            memory=memory,
            message_id=message.id,
            inputs=inputs,
-           vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
-               "enabled", False
+           vision_enabled=bool(
+               application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
+               .get("image", {})
+               .get("enabled", False)
            ),
        )
        context_files = retrieved_files or []
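The old expression read `enabled` directly under `file_upload`, but the flag sits one level deeper under `image`, so the nested lookup (plus the `bool(...)` wrapper) is the actual fix; a standalone illustration with a hypothetical config:

```python
config = {"file_upload": {"image": {"enabled": True}}}  # illustrative shape

old = config.get("file_upload", {}).get("enabled", False)  # False: wrong level
new = bool(
    config.get("file_upload", {}).get("image", {}).get("enabled", False)
)  # True
print(old, new)
```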
@@ -1,3 +1,5 @@
+from typing import Any, cast
+
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
@@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
 from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
-from models.model import App, AppMode, AppModelConfig
+from models.model import App, AppMode, AppModelConfig, AppModelConfigDict


 class CompletionAppConfig(EasyUIBasedAppConfig):
@@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
 class CompletionAppConfigManager(BaseAppConfigManager):
     @classmethod
     def get_app_config(
-        cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
+        cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
     ) -> CompletionAppConfig:
         """
         Convert app model config to completion app config
@@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
            app_model_config_dict = app_model_config.to_dict()
            config_dict = app_model_config_dict.copy()
        else:
-           config_dict = override_config_dict or {}
+           if not override_config_dict:
+               raise Exception("override_config_dict is required when config_from is ARGS")
+           config_dict = override_config_dict

        app_mode = AppMode.value_of(app_model.mode)
        app_config = CompletionAppConfig(
@@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
            app_mode=app_mode,
            app_model_config_from=config_from,
            app_model_config_id=app_model_config.id,
-           app_model_config_dict=config_dict,
+           app_model_config_dict=cast(dict[str, Any], config_dict),
            model=ModelConfigManager.convert(config=config_dict),
            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
@@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
        return app_config

    @classmethod
-   def config_validate(cls, tenant_id: str, config: dict):
+   def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
        """
        Validate for completion app model config

@@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

-       return filtered_config
+       return cast(AppModelConfigDict, filtered_config)
@@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
            raise ValueError("Message app_model_config is None")
        override_model_config_dict = app_model_config.to_dict()
        model_dict = override_model_config_dict["model"]
-       completion_params = model_dict.get("completion_params")
+       completion_params = model_dict.get("completion_params", {})
        completion_params["temperature"] = 0.9
        model_dict["completion_params"] = completion_params
        override_model_config_dict["model"] = model_dict
@@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner):
            hit_callback=hit_callback,
            message_id=message.id,
            inputs=inputs,
-           vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
-               "enabled", False
+           vision_enabled=bool(
+               application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
+               .get("image", {})
+               .get("enabled", False)
            ),
        )
        context_files = retrieved_files or []
@@ -8,12 +8,14 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.entities.app_invoke_entities import (
     InvokeFrom,
     RagPipelineGenerateEntity,
+    UserFrom,
+    build_dify_run_context,
 )
 from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.workflow_entry import WorkflowEntry
 from dify_graph.entities.graph_init_params import GraphInitParams
-from dify_graph.enums import UserFrom, WorkflowType
+from dify_graph.enums import WorkflowType
 from dify_graph.graph import Graph
 from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent
 from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository
@@ -256,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
        # init graph
        # Create required parameters for Graph.init
        graph_init_params = GraphInitParams(
-           tenant_id=workflow.tenant_id,
-           app_id=self._app_id,
            workflow_id=workflow.id,
            graph_config=graph_config,
-           user_id=self.application_generate_entity.user_id,
-           user_from=user_from,
-           invoke_from=invoke_from,
+           run_context=build_dify_run_context(
+               tenant_id=workflow.tenant_id,
+               app_id=self._app_id,
+               user_id=self.application_generate_entity.user_id,
+               user_from=user_from,
+               invoke_from=invoke_from,
+           ),
            call_depth=0,
        )
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
 from typing import Any, cast

 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
 from core.app.entities.queue_entities import (
     AppQueueEvent,
     QueueAgentLogEvent,
@@ -33,7 +33,6 @@ from core.workflow.node_factory import DifyNodeFactory
 from core.workflow.workflow_entry import WorkflowEntry
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.pause_reason import HumanInputRequired
-from dify_graph.enums import UserFrom
 from dify_graph.graph import Graph
 from dify_graph.graph_engine.layers.base import GraphEngineLayer
 from dify_graph.graph_events import (
@@ -119,13 +118,15 @@ class WorkflowBasedAppRunner:

        # Create required parameters for Graph.init
        graph_init_params = GraphInitParams(
-           tenant_id=tenant_id or "",
-           app_id=self._app_id,
            workflow_id=workflow_id,
            graph_config=graph_config,
-           user_id=user_id,
-           user_from=user_from,
-           invoke_from=invoke_from,
+           run_context=build_dify_run_context(
+               tenant_id=tenant_id or "",
+               app_id=self._app_id,
+               user_id=user_id,
+               user_from=user_from,
+               invoke_from=invoke_from,
+           ),
            call_depth=0,
        )
@@ -267,13 +268,15 @@ class WorkflowBasedAppRunner:

        # Create required parameters for Graph.init
        graph_init_params = GraphInitParams(
-           tenant_id=workflow.tenant_id,
-           app_id=self._app_id,
            workflow_id=workflow.id,
            graph_config=graph_config,
-           user_id="",
-           user_from=UserFrom.ACCOUNT,
-           invoke_from=InvokeFrom.DEBUGGER,
+           run_context=build_dify_run_context(
+               tenant_id=workflow.tenant_id,
+               app_id=self._app_id,
+               user_id="",
+               user_from=UserFrom.ACCOUNT,
+               invoke_from=InvokeFrom.DEBUGGER,
+           ),
            call_depth=0,
        )
@@ -1,4 +1,5 @@
 from collections.abc import Mapping, Sequence
+from enum import StrEnum
 from typing import TYPE_CHECKING, Any, Optional

 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
@@ -6,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
 from constants import UUID_NIL
 from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
 from core.entities.provider_configuration import ProviderModelBundle
-from dify_graph.enums import InvokeFrom
+from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
 from dify_graph.file import File, FileUploadConfig
 from dify_graph.model_runtime.entities.model_entities import AIModelEntity

@@ -14,6 +15,69 @@ if TYPE_CHECKING:
     from core.ops.ops_trace_manager import TraceQueueManager


+class UserFrom(StrEnum):
+    ACCOUNT = "account"
+    END_USER = "end-user"
+
+
+class InvokeFrom(StrEnum):
+    SERVICE_API = "service-api"
+    WEB_APP = "web-app"
+    TRIGGER = "trigger"
+    EXPLORE = "explore"
+    DEBUGGER = "debugger"
+    PUBLISHED_PIPELINE = "published"
+    VALIDATION = "validation"
+
+    @classmethod
+    def value_of(cls, value: str) -> "InvokeFrom":
+        return cls(value)
+
+    def to_source(self) -> str:
+        source_mapping = {
+            InvokeFrom.WEB_APP: "web_app",
+            InvokeFrom.DEBUGGER: "dev",
+            InvokeFrom.EXPLORE: "explore_app",
+            InvokeFrom.TRIGGER: "trigger",
+            InvokeFrom.SERVICE_API: "api",
+        }
+        return source_mapping.get(self, "dev")
+
+
+class DifyRunContext(BaseModel):
+    tenant_id: str
+    app_id: str
+    user_id: str
+    user_from: UserFrom
+    invoke_from: InvokeFrom
+
+
+def build_dify_run_context(
+    *,
+    tenant_id: str,
+    app_id: str,
+    user_id: str,
+    user_from: UserFrom,
+    invoke_from: InvokeFrom,
+    extra_context: Mapping[str, Any] | None = None,
+) -> dict[str, Any]:
+    """
+    Build graph run_context with the reserved Dify runtime payload.
+
+    `extra_context` can carry user-defined context keys. The reserved `_dify`
+    payload is always overwritten by this function to keep one canonical source.
+    """
+    run_context = dict(extra_context) if extra_context else {}
+    run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext(
+        tenant_id=tenant_id,
+        app_id=app_id,
+        user_id=user_id,
+        user_from=user_from,
+        invoke_from=invoke_from,
+    )
+    return run_context
+
+
 class ModelConfigWithCredentialsEntity(BaseModel):
     """
     Model Config With Credentials Entity.
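For orientation, a small usage sketch of the new helper (the IDs are placeholders): `extra_context` keys pass through, while the reserved `_dify` entry is always replaced with the canonical payload, as the docstring above states.

```python
from core.app.entities.app_invoke_entities import (
    DifyRunContext,
    InvokeFrom,
    UserFrom,
    build_dify_run_context,
)
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY

# Placeholder IDs; any caller-supplied value under the reserved key is replaced.
ctx = build_dify_run_context(
    tenant_id="tenant-1",
    app_id="app-1",
    user_id="user-1",
    user_from=UserFrom.ACCOUNT,
    invoke_from=InvokeFrom.DEBUGGER,
    extra_context={"request_id": "req-42", DIFY_RUN_CONTEXT_KEY: "stale"},
)
assert isinstance(ctx[DIFY_RUN_CONTEXT_KEY], DifyRunContext)
print(ctx[DIFY_RUN_CONTEXT_KEY].invoke_from.to_source())  # -> "dev"
print(ctx["request_id"])                                  # user-defined keys pass through
```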
@@ -2,7 +2,7 @@ import logging
 import time
 from collections.abc import Generator
 from threading import Thread
-from typing import Union, cast
+from typing import Any, Union, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
+from core.app.task_pipeline.message_file_utils import prepare_file_dict
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_manager import ModelInstance
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
-from core.tools.signature import sign_tool_file
-from dify_graph.file import helpers as file_helpers
 from dify_graph.file.enums import FileTransferMethod
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from dify_graph.model_runtime.entities.message_entities import (
@@ -219,14 +218,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        tenant_id = self._application_generate_entity.app_config.tenant_id
        task_id = self._application_generate_entity.task_id
        publisher = None
-       text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
+       text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
        if (
            text_to_speech_dict
            and text_to_speech_dict.get("autoPlay") == "enabled"
            and text_to_speech_dict.get("enabled")
        ):
            publisher = AppGeneratorTTSPublisher(
-               tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
+               tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
            )
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            while True:
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        """
        self._task_state.metadata.usage = self._task_state.llm_result.usage
        metadata_dict = self._task_state.metadata.model_dump()

        # Fetch files associated with this message
        files = None
        with Session(db.engine, expire_on_commit=False) as session:
            message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()

            if message_files:
                # Fetch all required UploadFile objects in a single query to avoid N+1 problem
                upload_file_ids = list(
                    dict.fromkeys(
                        mf.upload_file_id
                        for mf in message_files
                        if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
                    )
                )
                upload_files_map = {}
                if upload_file_ids:
                    upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
                    upload_files_map = {uf.id: uf for uf in upload_files}

                files_list = []
                for message_file in message_files:
                    file_dict = prepare_file_dict(message_file, upload_files_map)
                    files_list.append(file_dict)

                files = files_list or None

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=self._message_id,
            metadata=metadata_dict,
            files=files,
        )

-   def _record_files(self):
-       with Session(db.engine, expire_on_commit=False) as session:
-           message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
-           if not message_files:
-               return None
-
-           files_list = []
-           upload_file_ids = [
-               mf.upload_file_id
-               for mf in message_files
-               if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
-           ]
-           upload_files_map = {}
-           if upload_file_ids:
-               upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
-               upload_files_map = {uf.id: uf for uf in upload_files}
-
-           for message_file in message_files:
-               upload_file = None
-               if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
-                   upload_file = upload_files_map.get(message_file.upload_file_id)
-
-               url = None
-               filename = "file"
-               mime_type = "application/octet-stream"
-               size = 0
-               extension = ""
-
-               if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
-                   url = message_file.url
-                   if message_file.url:
-                       filename = message_file.url.split("/")[-1].split("?")[0]  # Remove query params
-               elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
-                   if upload_file:
-                       url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
-                       filename = upload_file.name
-                       mime_type = upload_file.mime_type or "application/octet-stream"
-                       size = upload_file.size or 0
-                       extension = f".{upload_file.extension}" if upload_file.extension else ""
-                   elif message_file.upload_file_id:
-                       # Fallback: generate URL even if upload_file not found
-                       url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
-               elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
-                   # For tool files, use URL directly if it's HTTP, otherwise sign it
-                   if message_file.url.startswith("http"):
-                       url = message_file.url
-                       filename = message_file.url.split("/")[-1].split("?")[0]
-                   else:
-                       # Extract tool file id and extension from URL
-                       url_parts = message_file.url.split("/")
-                       if url_parts:
-                           file_part = url_parts[-1].split("?")[0]  # Remove query params first
-                           # Use rsplit to correctly handle filenames with multiple dots
-                           if "." in file_part:
-                               tool_file_id, ext = file_part.rsplit(".", 1)
-                               extension = f".{ext}"
-                           else:
-                               tool_file_id = file_part
-                               extension = ".bin"
-                           url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
-                           filename = file_part
-
-               transfer_method_value = message_file.transfer_method
-               remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
-               file_dict = {
-                   "related_id": message_file.id,
-                   "extension": extension,
-                   "filename": filename,
-                   "size": size,
-                   "mime_type": mime_type,
-                   "transfer_method": transfer_method_value,
-                   "type": message_file.type,
-                   "url": url or "",
-                   "upload_file_id": message_file.upload_file_id or message_file.id,
-                   "remote_url": remote_url,
-               }
-               files_list.append(file_dict)
-           return files_list or None

    def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
        """
        Agent message to stream response.
api/core/app/task_pipeline/message_file_utils.py | 76 (new file)
@@ -0,0 +1,76 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile

MAX_TOOL_FILE_EXTENSION_LENGTH = 10


def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
    """
    Prepare file dictionary for message end stream response.

    :param message_file: MessageFile instance
    :param upload_files_map: Dictionary mapping upload_file_id to UploadFile
    :return: Dictionary containing file information
    """
    upload_file = None
    if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
        upload_file = upload_files_map.get(message_file.upload_file_id)

    url = None
    filename = "file"
    mime_type = "application/octet-stream"
    size = 0
    extension = ""

    if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
        url = message_file.url
        if message_file.url:
            filename = message_file.url.split("/")[-1].split("?")[0]
            if "." in filename:
                extension = "." + filename.rsplit(".", 1)[1]
    elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
        if upload_file:
            url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
            filename = upload_file.name
            mime_type = upload_file.mime_type or "application/octet-stream"
            size = upload_file.size or 0
            extension = f".{upload_file.extension}" if upload_file.extension else ""
        elif message_file.upload_file_id:
            url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
    elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
        if message_file.url.startswith(("http://", "https://")):
            url = message_file.url
            filename = message_file.url.split("/")[-1].split("?")[0]
            if "." in filename:
                extension = "." + filename.rsplit(".", 1)[1]
        else:
            url_parts = message_file.url.split("/")
            if url_parts:
                file_part = url_parts[-1].split("?")[0]
                if "." in file_part:
                    tool_file_id, ext = file_part.rsplit(".", 1)
                    extension = f".{ext}"
                    if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
                        extension = ".bin"
                else:
                    tool_file_id = file_part
                    extension = ".bin"
                url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
                filename = file_part

    transfer_method_value = message_file.transfer_method.value
    remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
    return {
        "related_id": message_file.id,
        "extension": extension,
        "filename": filename,
        "size": size,
        "mime_type": mime_type,
        "transfer_method": transfer_method_value,
        "type": message_file.type,
        "url": url or "",
        "upload_file_id": message_file.upload_file_id or message_file.id,
        "remote_url": remote_url,
    }
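The least obvious branch is the tool-file one, which recovers the tool file ID and extension from the URL path; its `rsplit` handling of multi-dot filenames and the new length guard can be illustrated standalone:

```python
# Standalone illustration of the tool-file parsing branch above.
MAX_TOOL_FILE_EXTENSION_LENGTH = 10

file_part = "files/tool/abc123.tar.gz".split("/")[-1].split("?")[0]  # "abc123.tar.gz"
if "." in file_part:
    tool_file_id, ext = file_part.rsplit(".", 1)  # rsplit keeps dots in the ID: ("abc123.tar", "gz")
    extension = f".{ext}"
    if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
        extension = ".bin"  # guard against absurd "extensions" parsed from odd URLs
else:
    tool_file_id, extension = file_part, ".bin"
print(tool_file_id, extension)  # abc123.tar .gz
```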
@@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer):
            return

        try:
+           dify_ctx = node.require_dify_context()
            deduct_llm_quota(
-               tenant_id=node.tenant_id,
+               tenant_id=dify_ctx.tenant_id,
                model_instance=model_instance,
                usage=result_event.node_run_result.llm_usage,
            )
@@ -7,7 +7,7 @@ import uuid
 from collections import deque
 from collections.abc import Sequence
 from datetime import datetime
-from typing import Final, cast
+from typing import Final
 from urllib.parse import urljoin

 import httpx
@@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int:
        raise ValueError("UUID cannot be None")
    try:
        uuid_obj = uuid.UUID(uuid_v4)
-       return cast(int, uuid_obj.int)
+       return uuid_obj.int
    except ValueError as e:
        raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
@@ -6,7 +6,6 @@ import hashlib
 import random
 import uuid
 from datetime import datetime
-from typing import cast

 from opentelemetry.trace import Link, SpanContext, TraceFlags

@@ -23,7 +22,7 @@ class TencentTraceUtils:
            uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4()
        except Exception as e:
            raise ValueError(f"Invalid UUID input: {e}")
-       return cast(int, uuid_obj.int)
+       return uuid_obj.int

    @staticmethod
    def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
@@ -52,9 +51,9 @@ class TencentTraceUtils:
    @staticmethod
    def create_link(trace_id_str: str) -> Link:
        try:
-           trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int)
+           trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int
        except (ValueError, TypeError):
-           trace_id = cast(int, uuid.uuid4().int)
+           trace_id = uuid.uuid4().int

        span_context = SpanContext(
            trace_id=trace_id,
@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Union
+from typing import Any, Union, cast

 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
            if workflow is None:
                raise ValueError("unexpected app type")

-           features_dict = workflow.features_dict
+           features_dict: dict[str, Any] = workflow.features_dict
            user_input_form = workflow.user_input_form(to_old_structure=True)
        else:
            app_model_config = app.app_model_config
            if app_model_config is None:
                raise ValueError("unexpected app type")

-           features_dict = app_model_config.to_dict()
+           features_dict = cast(dict[str, Any], app_model_config.to_dict())

        user_input_form = features_dict.get("user_input_form", [])
@@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
            self._client.get_or_create_collection(collection_name)
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

-   def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+   def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
        uuids = self._get_uuids(documents)
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
@@ -73,6 +73,7 @@ class ChromaVector(BaseVector):
        collection = self._client.get_or_create_collection(self._collection_name)
        # FIXME: chromadb using numpy array, fix the type error later
        collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)  # type: ignore
+       return uuids

    def delete_by_metadata_field(self, key: str, value: str):
        collection = self._client.get_or_create_collection(self._collection_name)
@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector):
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
"""Add documents with embeddings to the collection."""
if not documents:
return
return []

batch_size = self._config.batch_size
total_batches = (len(documents) + batch_size - 1) // batch_size
added_ids = []

for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
batch_doc_ids = []
for doc in batch_docs:
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
added_ids.extend(batch_doc_ids)

# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
self._execute_write(
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
)

return added_ids

def _insert_batch(
self,
batch_docs: list[Document],
batch_embeddings: list[list[float]],
batch_doc_ids: list[str],
batch_index: int,
batch_size: int,
total_batches: int,
@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector):
data_rows = []
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768

for doc, embedding in zip(batch_docs, batch_embeddings):
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata or {}

if not isinstance(metadata, dict):
metadata = {}

doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}

# Fast path for JSON serialization
try:

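The reworked `add_texts` derives every document id once, before writing, and returns the full list; `_insert_batch` then consumes the precomputed `batch_doc_ids` instead of regenerating ids mid-insert. A hedged sketch of the id-derivation contract (assuming `_safe_doc_id` only normalizes the raw value, which this diff does not show):

    import uuid
    from typing import Any

    def derive_doc_ids(metadatas: list[Any]) -> list[str]:
        # One id per document: prefer metadata["doc_id"], else mint a UUID.
        ids: list[str] = []
        for metadata in metadatas:
            metadata = metadata if isinstance(metadata, dict) else {}
            ids.append(str(metadata.get("doc_id", uuid.uuid4())))
        return ids

Because the same list feeds both `added_ids` and the batched insert, callers and storage now agree on the ids even for documents carrying no `doc_id` metadata.
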
@ -194,6 +194,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):

# Create a new database session
with self._session_factory() as session:
existing_model = session.get(WorkflowRun, db_model.id)
if existing_model:
if existing_model.tenant_id != self._tenant_id:
raise ValueError("Unauthorized access to workflow run")
# Preserve the original start time for pause/resume flows.
db_model.created_at = existing_model.created_at

# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)

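`session.merge()` acts as an upsert keyed on the primary key; the new guard re-reads the existing row first so a pause/resume cycle can neither cross tenants nor reset the run's original `created_at`. A minimal sketch of the merge semantics (hypothetical `WorkflowRun`-like model, not the repository itself):

    from sqlalchemy.orm import Session

    def save_run(session: Session, db_model, tenant_id: str) -> None:
        existing = session.get(type(db_model), db_model.id)
        if existing is not None:
            if existing.tenant_id != tenant_id:
                raise ValueError("Unauthorized access to workflow run")
            # Keep the first run's start time across pause/resume cycles.
            db_model.created_at = existing.created_at
        session.merge(db_model)  # INSERT when the PK is new, UPDATE otherwise
        session.commit()
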
@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from typing_extensions import override

from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
@ -22,6 +23,7 @@ from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
@ -110,6 +112,7 @@ class DifyNodeFactory(NodeFactory):
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._dify_context = self._resolve_dify_context(graph_init_params.run_context)
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
self._code_limits = CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
@ -141,7 +144,16 @@ class DifyNodeFactory(NodeFactory):
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)

self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)

@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
if isinstance(raw_ctx, DifyRunContext):
return raw_ctx
return DifyRunContext.model_validate(raw_ctx)

@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@ -213,7 +225,7 @@ class DifyNodeFactory(NodeFactory):
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
form_repository=HumanInputFormRepositoryImpl(tenant_id=self.graph_init_params.tenant_id),
form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
)

if node_type == NodeType.KNOWLEDGE_INDEX:
@ -356,7 +368,7 @@ class DifyNodeFactory(NodeFactory):
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
app_id=self._dify_context.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)

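`DifyNodeFactory` no longer reads tenant/app ids off `GraphInitParams` directly; identity now travels in the opaque `run_context` mapping under `DIFY_RUN_CONTEXT_KEY` (`"_dify"`), and `_resolve_dify_context` is the one place it is validated back into a typed `DifyRunContext`. A hedged sketch of the round trip (field names are those used elsewhere in this diff; `build_dify_run_context` internals are assumed):

    # Producer side packs identity into the mapping:
    run_context = {
        "_dify": {  # DIFY_RUN_CONTEXT_KEY
            "tenant_id": "t-1",
            "app_id": "a-1",
            "user_id": "u-1",
            "user_from": "account",
            "invoke_from": "debugger",
        }
    }

    # Consumer side fails fast, then validates into a typed object:
    raw_ctx = run_context.get("_dify")
    if raw_ctx is None:
        raise ValueError("run_context missing required key: _dify")
    # ...followed by DifyRunContext.model_validate(raw_ctx) as in the hunk above.
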
@ -5,26 +5,26 @@ from typing import Any, cast

from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict
from dify_graph.enums import UserFrom
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file.models import File
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from extensions.otel.runtime import is_instrument_flag_enabled
@ -34,6 +34,66 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)


class _WorkflowChildEngineBuilder:
@staticmethod
def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None:
"""
Return whether `graph_config["nodes"]` contains the given node id.

Returns `None` when the nodes payload shape is unexpected, so graph-level
validation can surface the original configuration error.
"""
nodes = graph_config.get("nodes")
if not isinstance(nodes, list):
return None

for node in nodes:
if not isinstance(node, Mapping):
return None
current_id = node.get("id")
if isinstance(current_id, str) and current_id == node_id:
return True
return False

def build_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> GraphEngine:
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)

has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id)
if has_root_node is False:
raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found")

child_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=root_node_id,
)

child_engine = GraphEngine(
workflow_id=workflow_id,
graph=child_graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
child_engine_builder=self,
)
child_engine.layer(LLMQuotaLayer())
for layer in layers:
child_engine.layer(cast(GraphEngineLayer, layer))
return child_engine

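`_WorkflowChildEngineBuilder` is handed to the root `GraphEngine` and re-handed to each child (`child_engine_builder=self`), so nested iteration and loop graphs can keep spawning engines without `dify_graph` importing app-layer modules. A minimal usage sketch (argument values are placeholders):

    builder = _WorkflowChildEngineBuilder()
    child_engine = builder.build_child_engine(
        workflow_id="wf-1",
        graph_init_params=graph_init_params,      # parent's init params
        graph_runtime_state=child_runtime_state,  # a fresh copy per child run
        graph_config=graph_init_params.graph_config,
        root_node_id="iteration-start-node",
        layers=(),                                # extra GraphEngineLayer objects
    )
    for event in child_engine.run():
        ...                                       # same event stream as the parent
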
class WorkflowEntry:
def __init__(
self,
@ -77,6 +137,7 @@ class WorkflowEntry:
command_channel = InMemoryChannel()

self.command_channel = command_channel
self._child_engine_builder = _WorkflowChildEngineBuilder()
self.graph_engine = GraphEngine(
workflow_id=workflow_id,
graph=graph,
@ -88,6 +149,7 @@ class WorkflowEntry:
scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD,
scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME,
),
child_engine_builder=self._child_engine_builder,
)

# Add debug logging layer when in debug mode
@ -154,13 +216,15 @@ class WorkflowEntry:

# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@ -293,13 +357,15 @@ class WorkflowEntry:

# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_id="",
graph_config=graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
run_context=build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

@ -3,7 +3,7 @@ from typing import Any

from pydantic import BaseModel, Field

from dify_graph.enums import InvokeFrom, UserFrom
DIFY_RUN_CONTEXT_KEY = "_dify"


class GraphInitParams(BaseModel):
@ -18,11 +18,7 @@ class GraphInitParams(BaseModel):
"""

# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
run_context: Mapping[str, Any] = Field(..., description="runtime context")
call_depth: int = Field(..., description="call depth")

@ -33,39 +33,6 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"


class UserFrom(StrEnum):
ACCOUNT = "account"
END_USER = "end-user"


class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"

@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)

def to_source(self) -> str:
"""Get source of invoke from.

:return: source
"""
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")


class NodeType(StrEnum):
START = "start"
END = "end"

@ -9,7 +9,7 @@ from __future__ import annotations

import logging
import queue
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, cast, final

from dify_graph.context import capture_current_context
@ -27,6 +27,7 @@ from dify_graph.graph_events (
GraphRunSucceededEvent,
)
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol

if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from dify_graph.runtime.graph_runtime_state import GraphProtocol
@ -49,6 +50,7 @@ from .protocols.command_channel import CommandChannel
from .worker_management import WorkerPool

if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator

@ -74,6 +76,7 @@ class GraphEngine:
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig = _DEFAULT_CONFIG,
child_engine_builder: ChildGraphEngineBuilderProtocol | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""

@ -83,6 +86,9 @@ class GraphEngine:
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
self._child_engine_builder = child_engine_builder
if child_engine_builder is not None:
self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)

# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
@ -214,6 +220,25 @@ class GraphEngine:
self._bind_layer_context(layer)
return self

def create_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: dict[str, object] | Mapping[str, object],
root_node_id: str,
layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
) -> GraphEngine:
return self._graph_runtime_state.create_child_engine(
workflow_id=workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
graph_config=graph_config,
root_node_id=root_node_id,
layers=layers,
)

def run(self) -> Generator[GraphEngineEvent, None, None]:
"""
Execute the graph using the modular architecture.

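The indirection here is deliberate: `GraphEngine.create_child_engine` forwards to `GraphRuntimeState.create_child_engine`, which calls whatever builder was bound at construction. The runtime package only ever sees `ChildGraphEngineBuilderProtocol`, so any object with a matching `build_child_engine` satisfies it. A tiny structural stand-in, for illustration only:

    class EchoBuilder:
        # Satisfies ChildGraphEngineBuilderProtocol structurally.
        def build_child_engine(self, *, workflow_id, graph_init_params,
                               graph_runtime_state, graph_config,
                               root_node_id, layers=()):
            return f"child engine rooted at {root_node_id}"

    # GraphEngine(..., child_engine_builder=EchoBuilder()) binds the builder
    # into the runtime state; create_child_engine(...) then forwards these
    # keywords. With no builder bound, ChildEngineBuilderNotConfiguredError
    # is raised instead.
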
@ -80,9 +80,11 @@ class AgentNode(Node[AgentNodeData]):
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError

dify_ctx = self.require_dify_context()

try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
@ -120,8 +122,8 @@ class AgentNode(Node[AgentNodeData]):
try:
message_stream = strategy.invoke(
params=parameters,
user_id=self.user_id,
app_id=self.app_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
@ -144,8 +146,8 @@ class AgentNode(Node[AgentNodeData]):
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
@ -283,8 +285,13 @@ class AgentNode(Node[AgentNodeData]):
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
@ -396,7 +403,8 @@ class AgentNode(Node[AgentNodeData]):
from core.plugin.impl.plugin import PluginInstaller

manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
@ -417,8 +425,11 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value

dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)

if not conversation:
@ -429,9 +440,10 @@ class AgentNode(Node[AgentNodeData]):
return memory

def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
@ -440,7 +452,7 @@ class AgentNode(Node[AgentNodeData]):
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,

@ -8,10 +8,11 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from types import MappingProxyType
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
from uuid import uuid4

from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import (
ErrorStrategy,
NodeExecutionType,
@ -64,10 +65,28 @@ from libs.datetime_utils import naive_utc_now
from .entities import BaseNodeData, RetryConfig

NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
_MISSING_RUN_CONTEXT_VALUE = object()

logger = logging.getLogger(__name__)


class DifyRunContextProtocol(Protocol):
tenant_id: str
app_id: str
user_id: str
user_from: Any
invoke_from: Any


class _MappingDifyRunContext:
def __init__(self, mapping: Mapping[str, Any]) -> None:
self.tenant_id = str(mapping["tenant_id"])
self.app_id = str(mapping["app_id"])
self.user_id = str(mapping["user_id"])
self.user_from = mapping["user_from"]
self.invoke_from = mapping["invoke_from"]


class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.

@ -227,14 +246,10 @@ class Node(Generic[NodeDataT]):
graph_runtime_state: GraphRuntimeState,
) -> None:
self._graph_init_params = graph_init_params
self._run_context = MappingProxyType(dict(graph_init_params.run_context))
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
@ -263,6 +278,38 @@ class Node(Generic[NodeDataT]):
def graph_init_params(self) -> GraphInitParams:
return self._graph_init_params

@property
def run_context(self) -> Mapping[str, Any]:
return self._run_context

def get_run_context_value(self, key: str, default: Any = None) -> Any:
return self._run_context.get(key, default)

def require_run_context_value(self, key: str) -> Any:
value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE)
if value is _MISSING_RUN_CONTEXT_VALUE:
raise ValueError(f"run_context missing required key: {key}")
return value

def require_dify_context(self) -> DifyRunContextProtocol:
raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)
if raw_ctx is None:
raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")

if isinstance(raw_ctx, Mapping):
missing_keys = [
key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx
]
if missing_keys:
raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}")
return _MappingDifyRunContext(raw_ctx)

for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"):
if not hasattr(raw_ctx, attr):
raise TypeError(f"invalid dify context object, missing attribute: {attr}")

return cast(DifyRunContextProtocol, raw_ctx)

@property
def execution_id(self) -> str:
return self._node_execution_id

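With the per-field attributes (`self.tenant_id`, `self.app_id`, and friends) gone from the base class, node code reads identity through `require_dify_context()`, which accepts either a `DifyRunContext`-like object or a plain mapping and fails loudly on anything else. A minimal usage sketch inside a hypothetical node method:

    def _run(self):
        dify_ctx = self.require_dify_context()
        # Every identity field now comes from one validated object:
        tenant_id = dify_ctx.tenant_id
        app_id = dify_ctx.app_id
        user_id = dify_ctx.user_id

    # Mapping payloads are accepted too and wrapped on the fly:
    ctx = _MappingDifyRunContext({
        "tenant_id": "t-1", "app_id": "a-1", "user_id": "u-1",
        "user_from": "account", "invoke_from": "web-app",
    })
    assert ctx.tenant_id == "t-1"
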
@ -52,6 +52,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
Run the datasource node
"""

dify_ctx = self.require_dify_context()
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
@ -75,7 +76,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
datasource_type=datasource_type.value,
)

@ -104,11 +105,11 @@ class DatasourceNode(Node[DatasourceNodeData]):

yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=self.user_id,
user_id=dify_ctx.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
@ -136,7 +137,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
raise DatasourceNodeError("File is not exist")

file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=self.tenant_id
file_id=related_id, tenant_id=dify_ctx.tenant_id
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())

@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
"""
Extract files from response by checking both Content-Type header and URL
"""
dify_ctx = self.require_dify_context()
files: list[File] = []
is_file = response.is_file
content_type = response.content_type
@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
tool_file_manager = self._tool_file_manager_factory()

tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=mime_type,
@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
)
files.append(file)

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from dify_graph.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
@ -31,6 +31,8 @@ if TYPE_CHECKING:


_SELECTED_BRANCH_KEY = "selected_branch"
_INVOKE_FROM_DEBUGGER = "debugger"
_INVOKE_FROM_EXPLORE = "explore"


logger = logging.getLogger(__name__)
@ -155,30 +157,39 @@ class HumanInputNode(Node[HumanInputNodeData]):
return resolved_defaults

def _should_require_console_recipient(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
invoke_from = self._invoke_from_value()
if invoke_from == _INVOKE_FROM_DEBUGGER:
return True
if self.invoke_from == InvokeFrom.EXPLORE:
if invoke_from == _INVOKE_FROM_EXPLORE:
return self._node_data.is_webapp_enabled()
return False

def _display_in_ui(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER:
return True
return self._node_data.is_webapp_enabled()

def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
dify_ctx = self.require_dify_context()
invoke_from = self._invoke_from_value()
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
user_id=self.user_id or "",
enabled=invoke_from == _INVOKE_FROM_DEBUGGER,
user_id=dify_ctx.user_id,
)
for method in enabled_methods
]

def _invoke_from_value(self) -> str:
invoke_from = self.require_dify_context().invoke_from
if isinstance(invoke_from, str):
return invoke_from
return str(getattr(invoke_from, "value", invoke_from))

def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
@ -212,10 +223,11 @@ class HumanInputNode(Node[HumanInputNodeData]):
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
dify_ctx = self.require_dify_context()
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=self.app_id,
app_id=dify_ctx.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
@ -225,7 +237,9 @@ class HumanInputNode(Node[HumanInputNodeData]):
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
dify_ctx.user_id
if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}
else None
),
backstage_recipient_required=True,
)

@ -587,24 +587,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return

def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState

# Create GraphInitParams from node attributes
# Create GraphInitParams for child graph execution.
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
run_context=self.run_context,
call_depth=self.workflow_call_depth,
)
# Create a deep copy of the variable pool for each iteration
@ -621,28 +611,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
total_tokens=0,
node_run_steps=0,
)
root_node_id = self.node_data.start_node_id
if root_node_id is None:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")

# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)

# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)

if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")

# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
workflow_id=self.workflow_id,
graph=iteration_graph,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())

return graph_engine
try:
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
graph_config=self.graph_config,
root_node_id=root_node_id,
)
except ChildGraphNotFoundError as exc:
raise IterationGraphNotFoundError("iteration graph not found") from exc

@ -3,7 +3,7 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import InvokeFrom, NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.template import Template
@ -20,6 +20,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState

logger = logging.getLogger(__name__)
_INVOKE_FROM_DEBUGGER = "debugger"


class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
@ -58,7 +59,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
if not variable:
raise KnowledgeIndexNodeError("Index chunk variable is required.")
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from.value == InvokeFrom.DEBUGGER if invoke_from else False
invoke_from_value = str(invoke_from.value) if invoke_from else None
is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER

chunks = variable.value
variables = {"chunks": chunks}

@ -66,9 +66,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
self._rag_retrieval = rag_retrieval

if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver

@ -115,7 +116,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD

try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@ -160,6 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[Source], LLMUsage]:
dify_ctx = self.require_dify_context()
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
@ -176,10 +178,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
@ -229,10 +231,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD

retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
app_id=dify_ctx.app_id,
tenant_id=dify_ctx.tenant_id,
user_id=dify_ctx.user_id,
user_from=dify_ctx.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,

@ -145,9 +145,10 @@ class LLMNode(Node[LLMNodeData]):
self._memory = memory

if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver

@ -242,7 +243,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
user_id=self.require_dify_context().user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
@ -702,7 +703,7 @@ class LLMNode(Node[LLMNodeData]):
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
tenant_id=self.require_dify_context().tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,

@ -412,24 +412,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return build_segment_with_type(var_type, value)

def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.graph import Graph
from dify_graph.graph_engine import GraphEngine, GraphEngineConfig
from dify_graph.graph_engine.command_channels import InMemoryChannel
from dify_graph.runtime import GraphRuntimeState

# Create GraphInitParams from node attributes
# Create GraphInitParams for child graph execution.
graph_init_params = GraphInitParams(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
graph_config=self.graph_config,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
run_context=self.run_context,
call_depth=self.workflow_call_depth,
)

@ -439,22 +429,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
start_at=start_at.timestamp(),
)

# Create a new node factory with the new GraphRuntimeState
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy
)

# Initialize the loop graph with the new node factory
loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id)

# Create a new GraphEngine for this iteration
graph_engine = GraphEngine(
return self.graph_runtime_state.create_child_engine(
workflow_id=self.workflow_id,
graph=loop_graph,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state_copy,
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
graph_config=self.graph_config,
root_node_id=root_node_id,
)
graph_engine.layer(LLMQuotaLayer())

return graph_engine

@ -297,7 +297,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
tools=tools,
stop=list(stop),
stream=False,
user=self.user_id,
user=self.require_dify_context().user_id,
)

# handle invoke result

@ -86,9 +86,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._memory = memory

if llm_file_saver is None:
dify_ctx = self.require_dify_context()
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
)
self._llm_file_saver = llm_file_saver

@ -160,7 +161,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
user_id=self.require_dify_context().user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,

@ -56,6 +56,8 @@ class ToolNode(Node[ToolNodeData]):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError

dify_ctx = self.require_dify_context()

# fetch tool icon
tool_info = {
"provider_type": self.node_data.provider_type.value,
@ -75,7 +77,12 @@ class ToolNode(Node[ToolNodeData]):
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
dify_ctx.tenant_id,
dify_ctx.app_id,
self._node_id,
self.node_data,
dify_ctx.invoke_from,
variable_pool,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@ -109,10 +116,10 @@ class ToolNode(Node[ToolNodeData]):
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
user_id=dify_ctx.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.app_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
@ -133,8 +140,8 @@ class ToolNode(Node[ToolNodeData]):
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_id=self._node_id,
tool_runtime=tool_runtime,
)

@ -69,6 +69,7 @@ class TriggerWebhookNode(Node[WebhookData]):
)

def generate_file_var(self, param_name: str, file: dict):
dify_ctx = self.require_dify_context()
related_id = file.get("related_id")
transfer_method_value = file.get("transfer_method")
if transfer_method_value:
@ -84,7 +85,7 @@ class TriggerWebhookNode(Node[WebhookData]):
try:
file_obj = file_factory.build_from_mapping(
mapping=file,
tenant_id=self.tenant_id,
tenant_id=dify_ctx.tenant_id,
)
file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])

@ -1,9 +1,17 @@
from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state import (
ChildEngineBuilderNotConfiguredError,
ChildEngineError,
ChildGraphNotFoundError,
GraphRuntimeState,
)
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue

__all__ = [
"ChildEngineBuilderNotConfiguredError",
"ChildEngineError",
"ChildGraphNotFoundError",
"GraphRuntimeState",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",

@ -15,6 +15,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime.variable_pool import VariablePool

if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.entities.pause_reason import PauseReason


@ -135,6 +136,31 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...


class ChildGraphEngineBuilderProtocol(Protocol):
def build_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> Any: ...


class ChildEngineError(ValueError):
"""Base error type for child-engine creation failures."""


class ChildEngineBuilderNotConfiguredError(ChildEngineError):
"""Raised when child-engine creation is requested without a bound builder."""


class ChildGraphNotFoundError(ChildEngineError):
"""Raised when the requested child graph entry point cannot be resolved."""


class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""

@ -209,6 +235,7 @@ class GraphRuntimeState:
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None

# Node and edges states needed to be restored into
# graph object.
@ -250,6 +277,31 @@ class GraphRuntimeState:
if self._graph is not None:
_ = self.response_coordinator

def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None:
self._child_engine_builder = builder

def create_child_engine(
self,
*,
workflow_id: str,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
graph_config: Mapping[str, Any],
root_node_id: str,
layers: Sequence[object] = (),
) -> Any:
if self._child_engine_builder is None:
raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.")

return self._child_engine_builder.build_child_engine(
workflow_id=workflow_id,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
graph_config=graph_config,
root_node_id=root_node_id,
layers=layers,
)

# ------------------------------------------------------------------
# Primary collaborators
# ------------------------------------------------------------------

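Since the new exceptions subclass `ValueError`, pre-existing broad handlers keep working, while callers can branch on the specific failure, exactly as the iteration node hunk earlier in this diff does. Roughly (names here are placeholders):

    try:
        engine = state.create_child_engine(
            workflow_id="wf",
            graph_init_params=params,
            graph_runtime_state=state,
            graph_config=config,
            root_node_id="n1",
        )
    except ChildGraphNotFoundError:
        ...  # the configured entry node is missing from graph_config["nodes"]
    except ChildEngineBuilderNotConfiguredError:
        ...  # this engine was built without a bound builder
    # or simply: except ChildEngineError / except ValueError covers both.
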
@ -65,9 +65,15 @@ class VariablePool(BaseModel):
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
# Add conversation variables to the variable pool. When restoring from a serialized
# snapshot, `variable_dictionary` already carries the latest runtime values.
# In that case, keep existing entries instead of overwriting them with the
# bootstrap list.
for var in self.conversation_variables:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
if self._has(selector):
continue
self.add(selector, var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)

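The `_has` guard changes bootstrap semantics for conversation variables only: a fresh pool still receives the declared defaults, but a pool deserialized from a paused run already holds live values in `variable_dictionary`, and those now win. Roughly (the exact value of `CONVERSATION_VARIABLE_NODE_ID` is assumed here):

    for var in conversation_variables:
        selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
        # Fresh run: selector absent -> the declared default is inserted.
        # Resume from snapshot: selector present with runtime value -> kept.
        if pool._has(selector):
            continue
        pool.add(selector, var)
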
@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"

@ -1,3 +1,5 @@
from typing import Any, cast

from sqlalchemy import select

from events.app_event import app_model_config_was_updated
@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s
continue

tool_type = list(tool.keys())[0]
tool_config = list(tool.values())[0]
tool_config = cast(dict[str, Any], list(tool.values())[0])
if tool_type == "dataset":
dataset_ids.add(tool_config.get("id"))
dataset_id = tool_config.get("id")
if isinstance(dataset_id, str):
dataset_ids.add(dataset_id)

# get dataset from dataset_configs
dataset_configs = app_model_config.dataset_configs_dict

@ -13,6 +13,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@ -66,6 +67,7 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

@ -66,6 +66,7 @@ def run_migrations_offline():
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
)
logger.info("Generating offline migration SQL with url: %s", url)

with context.begin_transaction():
context.run_migrations()

@ -0,0 +1,37 @@
"""add partial indexes on conversations for app_id with created_at and updated_at

Revision ID: e288952f2994
Revises: fce013ca180e
Create Date: 2026-02-26 13:36:45.928922

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = 'e288952f2994'
down_revision = 'fce013ca180e'
branch_labels = None
depends_on = None


def upgrade():
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.create_index(
'conversation_app_created_at_idx',
['app_id', sa.literal_column('created_at DESC')],
unique=False,
postgresql_where=sa.text('is_deleted IS false'),
)
batch_op.create_index(
'conversation_app_updated_at_idx',
['app_id', sa.literal_column('updated_at DESC')],
unique=False,
postgresql_where=sa.text('is_deleted IS false'),
)


def downgrade():
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_index('conversation_app_updated_at_idx')
batch_op.drop_index('conversation_app_created_at_idx')
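On PostgreSQL the two `create_index` calls above should render as partial descending indexes, sketched here as comments for reference (Alembic's generated SQL may differ in detail):

    # Equivalent PostgreSQL DDL (illustrative):
    #
    #   CREATE INDEX conversation_app_created_at_idx
    #       ON conversations (app_id, created_at DESC)
    #       WHERE is_deleted IS false;
    #
    #   CREATE INDEX conversation_app_updated_at_idx
    #       ON conversations (app_id, updated_at DESC)
    #       WHERE is_deleted IS false;

The `is_deleted IS false` predicate keeps soft-deleted rows out of the index, so app-scoped recency queries scan only live conversations.
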
@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
from datetime import datetime
from decimal import Decimal
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
from uuid import uuid4

import sqlalchemy as sa
@ -15,6 +15,7 @@ from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from typing_extensions import TypedDict

from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
@ -36,6 +37,259 @@ if TYPE_CHECKING:
from .workflow import Workflow


# --- TypedDict definitions for structured dict return types ---


class EnabledConfig(TypedDict):
enabled: bool


class EmbeddingModelInfo(TypedDict):
embedding_provider_name: str
embedding_model_name: str


class AnnotationReplyDisabledConfig(TypedDict):
enabled: Literal[False]


class AnnotationReplyEnabledConfig(TypedDict):
id: str
enabled: Literal[True]
score_threshold: float
embedding_model: EmbeddingModelInfo


AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig

class SensitiveWordAvoidanceConfig(TypedDict):
|
||||
enabled: bool
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class AgentToolConfig(TypedDict):
|
||||
provider_type: str
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any]
|
||||
plugin_unique_identifier: NotRequired[str | None]
|
||||
credential_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class AgentModeConfig(TypedDict):
|
||||
enabled: bool
|
||||
strategy: str | None
|
||||
tools: list[AgentToolConfig | dict[str, Any]]
|
||||
prompt: str | None
|
||||
|
||||
|
||||
class ImageUploadConfig(TypedDict):
|
||||
enabled: bool
|
||||
number_limits: int
|
||||
detail: str
|
||||
transfer_methods: list[str]
|
||||
|
||||
|
||||
class FileUploadConfig(TypedDict):
|
||||
image: ImageUploadConfig
|
||||
|
||||
|
||||
class DeletedToolInfo(TypedDict):
|
||||
type: str
|
||||
tool_name: str
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ExternalDataToolConfig(TypedDict):
|
||||
enabled: bool
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class UserInputFormItemConfig(TypedDict):
|
||||
variable: str
|
||||
label: str
|
||||
description: NotRequired[str]
|
||||
required: NotRequired[bool]
|
||||
max_length: NotRequired[int]
|
||||
options: NotRequired[list[str]]
|
||||
default: NotRequired[str]
|
||||
type: NotRequired[str]
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
|
||||
UserInputFormItem = dict[str, UserInputFormItemConfig]
|
||||
|
||||
|
||||
class DatasetConfigs(TypedDict):
|
||||
retrieval_model: str
|
||||
datasets: NotRequired[dict[str, Any]]
|
||||
top_k: NotRequired[int]
|
||||
score_threshold: NotRequired[float]
|
||||
score_threshold_enabled: NotRequired[bool]
|
||||
reranking_model: NotRequired[dict[str, Any] | None]
|
||||
weights: NotRequired[dict[str, Any] | None]
|
||||
reranking_enabled: NotRequired[bool]
|
||||
reranking_mode: NotRequired[str]
|
||||
metadata_filtering_mode: NotRequired[str]
|
||||
metadata_model_config: NotRequired[dict[str, Any] | None]
|
||||
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
|
||||
|
||||
|
||||
class ChatPromptMessage(TypedDict):
|
||||
text: str
|
||||
role: str
|
||||
|
||||
|
||||
class ChatPromptConfig(TypedDict, total=False):
|
||||
prompt: list[ChatPromptMessage]
|
||||
|
||||
|
||||
class CompletionPromptText(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class ConversationHistoriesRole(TypedDict):
|
||||
user_prefix: str
|
||||
assistant_prefix: str
|
||||
|
||||
|
||||
class CompletionPromptConfig(TypedDict):
|
||||
prompt: CompletionPromptText
|
||||
conversation_histories_role: NotRequired[ConversationHistoriesRole]
|
||||
|
||||
|
||||
class ModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class AppModelConfigDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: EnabledConfig
|
||||
speech_to_text: EnabledConfig
|
||||
text_to_speech: EnabledConfig
|
||||
retriever_resource: EnabledConfig
|
||||
annotation_reply: AnnotationReplyConfig
|
||||
more_like_this: EnabledConfig
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
|
||||
external_data_tools: list[ExternalDataToolConfig]
|
||||
model: ModelConfig
|
||||
user_input_form: list[UserInputFormItem]
|
||||
dataset_query_variable: str | None
|
||||
pre_prompt: str | None
|
||||
agent_mode: AgentModeConfig
|
||||
prompt_type: str
|
||||
chat_prompt_config: ChatPromptConfig
|
||||
completion_prompt_config: CompletionPromptConfig
|
||||
dataset_configs: DatasetConfigs
|
||||
file_upload: FileUploadConfig
|
||||
# Added dynamically in Conversation.model_config
|
||||
model_id: NotRequired[str | None]
|
||||
provider: NotRequired[str | None]
|
||||
|
||||
|
||||
class ConversationDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
app_model_config_id: str | None
|
||||
model_provider: str | None
|
||||
override_model_configs: str | None
|
||||
model_id: str | None
|
||||
mode: str
|
||||
name: str
|
||||
summary: str | None
|
||||
inputs: dict[str, Any]
|
||||
introduction: str | None
|
||||
system_instruction: str | None
|
||||
system_instruction_tokens: int
|
||||
status: str
|
||||
invoke_from: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
read_at: datetime | None
|
||||
read_account_id: str | None
|
||||
dialogue_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
model_id: str | None
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
total_price: Decimal | None
|
||||
message: dict[str, Any]
|
||||
answer: str
|
||||
status: str
|
||||
error: str | None
|
||||
message_metadata: dict[str, Any]
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
agent_based: bool
|
||||
workflow_run_id: str | None
|
||||
|
||||
|
||||
class MessageFeedbackDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class MessageFileInfo(TypedDict, total=False):
|
||||
belongs_to: str | None
|
||||
upload_file_id: str | None
|
||||
id: str
|
||||
tenant_id: str
|
||||
type: str
|
||||
transfer_method: str
|
||||
remote_url: str | None
|
||||
related_id: str | None
|
||||
filename: str | None
|
||||
extension: str | None
|
||||
mime_type: str | None
|
||||
size: int
|
||||
dify_model_identity: str
|
||||
url: str | None
|
||||
|
||||
|
||||
class ExtraContentDict(TypedDict, total=False):
|
||||
type: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class TraceAppConfigDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
tracing_provider: str | None
|
||||
tracing_config: dict[str, Any]
|
||||
is_active: bool
|
||||
created_at: str | None
|
||||
updated_at: str | None
|
||||
|
||||
|
||||
class DifySetup(TypeBase):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
@ -176,7 +430,7 @@ class App(Base):
|
||||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list[dict[str, str]]:
|
||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
@ -257,7 +511,7 @@ class App(Base):
|
||||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools: list[dict[str, str]] = []
|
||||
deleted_tools: list[DeletedToolInfo] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
|
||||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.model) if self.model else {}
|
||||
def model_dict(self) -> ModelConfig:
|
||||
return cast(ModelConfig, json.loads(self.model) if self.model else {})
|
||||
|
||||
@property
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False}
|
||||
else {"enabled": False},
|
||||
)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> dict[str, Any]:
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
|
||||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
return cast(
|
||||
SensitiveWordAvoidanceConfig,
|
||||
json.loads(self.sensitive_word_avoidance)
|
||||
if self.sensitive_word_avoidance
|
||||
else {"enabled": False, "type": "", "configs": []}
|
||||
else {"enabled": False, "type": "", "config": {}},
|
||||
)
|
||||
|
||||
@property
|
||||
def external_data_tools_list(self) -> list[dict[str, Any]]:
|
||||
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self) -> list[dict[str, Any]]:
|
||||
def user_input_form_list(self) -> list[UserInputFormItem]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
def agent_mode_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def agent_mode_dict(self) -> AgentModeConfig:
|
||||
return cast(
|
||||
AgentModeConfig,
|
||||
json.loads(self.agent_mode)
|
||||
if self.agent_mode
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
def chat_prompt_config_dict(self) -> ChatPromptConfig:
|
||||
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
|
||||
return cast(
|
||||
CompletionPromptConfig,
|
||||
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict[str, Any]:
|
||||
def dataset_configs_dict(self) -> DatasetConfigs:
|
||||
if self.dataset_configs:
|
||||
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
return dataset_configs
|
||||
return cast(DatasetConfigs, dataset_configs)
|
||||
return {
|
||||
"retrieval_model": "multiple",
|
||||
}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def file_upload_dict(self) -> FileUploadConfig:
|
||||
return cast(
|
||||
FileUploadConfig,
|
||||
json.loads(self.file_upload)
|
||||
if self.file_upload
|
||||
else {
|
||||
@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: Mapping[str, Any]):
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config["suggested_questions_after_answer"])
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
|
||||
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
|
||||
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config["sensitive_word_avoidance"])
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
)
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.prompt_type = model_config.get("prompt_type", "simple")
|
||||
self.chat_prompt_config = (
|
||||
@ -711,6 +980,18 @@ class Conversation(Base):
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="conversation_pkey"),
|
||||
sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||
sa.Index(
|
||||
"conversation_app_created_at_idx",
|
||||
"app_id",
|
||||
sa.text("created_at DESC"),
|
||||
postgresql_where=sa.text("is_deleted IS false"),
|
||||
),
|
||||
sa.Index(
|
||||
"conversation_app_updated_at_idx",
|
||||
"app_id",
|
||||
sa.text("updated_at DESC"),
|
||||
postgresql_where=sa.text("is_deleted IS false"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
@ -811,24 +1092,26 @@ class Conversation(Base):
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
def model_config(self) -> AppModelConfigDict:
|
||||
model_config = cast(AppModelConfigDict, {})
|
||||
app_model_config: AppModelConfig | None = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = override_model_configs
|
||||
model_config = cast(AppModelConfigDict, override_model_configs)
|
||||
else:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
|
||||
cast(AppModelConfigDict, override_model_configs)
|
||||
)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
@ -1003,7 +1286,7 @@ class Conversation(Base):
|
||||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ConversationDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@ -1283,7 +1566,7 @@ class Message(Base):
|
||||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||
|
||||
@property
|
||||
def message_files(self) -> list[dict[str, Any]]:
|
||||
def message_files(self) -> list[MessageFileInfo]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
@ -1338,10 +1621,13 @@ class Message(Base):
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result: list[dict[str, Any]] = [
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
result = cast(
|
||||
list[MessageFileInfo],
|
||||
[
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
],
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
return result
|
||||
@ -1351,7 +1637,7 @@ class Message(Base):
|
||||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
def extra_contents(self) -> list[ExtraContentDict]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
@ -1367,7 +1653,7 @@ class Message(Base):
|
||||
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@ -1391,7 +1677,7 @@ class Message(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||
def from_dict(cls, data: MessageDict) -> Message:
|
||||
return cls(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
@ -1451,7 +1737,7 @@ class MessageFeedback(TypeBase):
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
@ -1714,8 +2000,8 @@ class AppMCPServer(TypeBase):
|
||||
return result
|
||||
|
||||
@property
|
||||
def parameters_dict(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.parameters))
|
||||
def parameters_dict(self) -> dict[str, str]:
|
||||
return cast(dict[str, str], json.loads(self.parameters))
|
||||
|
||||
|
||||
class Site(Base):
|
||||
@ -2155,7 +2441,7 @@ class TraceAppConfig(TypeBase):
|
||||
def tracing_config_str(self) -> str:
|
||||
return json.dumps(self.tracing_config_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> TraceAppConfigDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
||||
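
The property rewrites above repeat one pattern: cast(SomeTypedDict, json.loads(...)) to narrow an untyped JSON payload to a structured return type. A minimal self-contained sketch of that pattern (EnabledConfig mirrors the TypedDict added above; the sample payload is made up):

    import json
    from typing import TypedDict, cast

    class EnabledConfig(TypedDict):
        enabled: bool

    def parse_enabled(raw: str | None) -> EnabledConfig:
        # cast() performs no runtime conversion; it only narrows the static type
        # of json.loads()'s Any result, exactly as the properties above do.
        return cast(EnabledConfig, json.loads(raw) if raw else {"enabled": False})

    assert parse_enabled('{"enabled": true}') == {"enabled": True}
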
@ -35,7 +35,7 @@ dependencies = [
    "jsonschema>=4.25.1",
    "langfuse~=2.51.3",
    "langsmith~=0.1.77",
    "markdown~=3.5.1",
    "markdown~=3.8.1",
    "mlflow-skinny>=3.0.0",
    "numpy~=1.26.4",
    "openpyxl~=3.1.5",
@ -113,7 +113,7 @@ dev = [
    "dotenv-linter~=0.5.0",
    "faker~=38.2.0",
    "lxml-stubs~=0.5.1",
    "basedpyright~=1.31.0",
    "basedpyright~=1.38.2",
    "ruff~=0.14.0",
    "pytest~=8.3.2",
    "pytest-benchmark~=4.0.0",
@ -167,12 +167,12 @@ dev = [
    "import-linter>=2.3",
    "types-redis>=4.6.0.20241004",
    "celery-types>=0.23.0",
    "mypy~=1.17.1",
    "mypy~=1.19.1",
    # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
    "sseclient-py>=1.8.0",
    "pytest-timeout>=2.4.0",
    "pytest-xdist>=3.8.0",
    "pyrefly>=0.54.0",
    "pyrefly>=0.55.0",
]

############################################################
@ -247,3 +247,13 @@ module = [
    "extensions.logstore.repositories.logstore_api_workflow_run_repository",
]
ignore_errors = true

[tool.pyrefly]
project-includes = ["."]
project-excludes = [
    ".venv",
    "migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false

200 api/pyrefly-local-excludes.txt Normal file
@ -0,0 +1,200 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
controllers/console/app/mcp_server.py
controllers/console/app/site.py
controllers/console/auth/email_register.py
controllers/console/human_input_form.py
controllers/console/init_validate.py
controllers/console/ping.py
controllers/console/setup.py
controllers/console/version.py
controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
core/ops/aliyun_trace/data_exporter/traceclient.py
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
core/ops/mlflow_trace/mlflow_trace.py
core/ops/ops_trace_manager.py
core/ops/tencent_trace/client.py
core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
core/repositories/sqlalchemy_workflow_node_execution_repository.py
core/tools/__base/tool.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/utils/message_transformer.py
core/tools/utils/web_reader_tool.py
core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
services/external_knowledge_service.py
services/plugin/plugin_migration.py
services/recommend_app/buildin/buildin_retrieval.py
services/recommend_app/database/database_retrieval.py
services/recommend_app/remote/remote_retrieval.py
services/summary_index_service.py
services/tools/tools_transform_service.py
services/trigger/trigger_provider_service.py
services/trigger/trigger_subscription_builder_service.py
services/trigger/webhook_service.py
services/workflow_draft_variable_service.py
services/workflow_event_snapshot_service.py
services/workflow_service.py
tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/workflow_execution_tasks.py
@ -1,8 +0,0 @@
project-includes = ["."]
project-excludes = [
    ".venv",
    "migrations/",
]
python-platform = "linux"
python-version = "3.11.0"
infer-with-first-use = false
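
Note: the [tool.pyrefly] table added to pyproject.toml above carries the same keys as this deleted standalone pyrefly.toml, so pyrefly configuration now lives in one place alongside the other type-checker settings.
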
@ -1,5 +1,6 @@
[pytest]
addopts = --cov=./api --cov-report=json
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
env =
    ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
    AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@ -19,7 +20,7 @@ env =
    GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
    HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
    HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
    HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
    HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
    MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
    MOCK_SWITCH = true

@ -4,6 +4,7 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4

@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
from factories import variable_factory
from libs.datetime_utils import naive_utc_now
from models import Account, App, AppMode
from models.model import AppModelConfig, IconType
from models.model import AppModelConfig, AppModelConfigDict, IconType
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
from services.workflow_draft_variable_service import WorkflowDraftVariableService
@ -523,7 +524,7 @@ class AppDslService:
            if not app.app_model_config:
                app_model_config = AppModelConfig(
                    app_id=app.id, created_by=account.id, updated_by=account.id
                ).from_model_config_dict(model_config)
                ).from_model_config_dict(cast(AppModelConfigDict, model_config))
                app_model_config.id = str(uuid4())
                app.app_model_config_id = app_model_config.id


@ -1,12 +1,12 @@
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from models.model import AppMode
from models.model import AppMode, AppModelConfigDict


class AppModelConfigService:
    @classmethod
    def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
    def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
        if app_mode == AppMode.CHAT:
            return ChatAppConfigManager.config_validate(tenant_id, config)
        elif app_mode == AppMode.AGENT_CHAT:

@ -1,6 +1,6 @@
import json
import logging
from typing import TypedDict, cast
from typing import Any, TypedDict, cast

import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
@ -187,7 +187,7 @@ class AppService:
        for tool in agent_mode.get("tools") or []:
            if not isinstance(tool, dict) or len(tool.keys()) <= 3:
                continue
            agent_tool_entity = AgentToolEntity(**tool)
            agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
            # get tool
            try:
                tool_runtime = ToolManager.get_agent_tool_runtime(
@ -388,7 +388,7 @@ class AppService:
        agent_config = app_model_config.agent_mode_dict

        # get all tools
        tools = agent_config.get("tools", [])
        tools = cast(list[dict[str, Any]], agent_config.get("tools", []))

        url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"

@ -2,6 +2,7 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import cast

from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@ -106,7 +107,7 @@ class AudioService:
        if not text_to_speech_dict.get("enabled"):
            raise ValueError("TTS is not enabled")

        voice = text_to_speech_dict.get("voice")
        voice = cast(str | None, text_to_speech_dict.get("voice"))

        model_manager = ModelManager()
        model_instance = model_manager.get_default_model_instance(

@ -63,7 +63,12 @@ class RagPipelineTransformService:
            ):
                node = self._deal_file_extensions(node)
            if node.get("data", {}).get("type") == "knowledge-index":
                node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
                knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
                if dataset.tenant_id != current_user.current_tenant_id:
                    raise ValueError("Unauthorized")
                node = self._deal_knowledge_index(
                    knowledge_configuration, dataset, indexing_technique, retrieval_model, node
                )
            new_nodes.append(node)
        if new_nodes:
            graph["nodes"] = new_nodes
@ -155,14 +160,13 @@ class RagPipelineTransformService:

    def _deal_knowledge_index(
        self,
        knowledge_configuration: KnowledgeConfiguration,
        dataset: Dataset,
        doc_form: str,
        indexing_technique: str | None,
        retrieval_model: RetrievalSetting | None,
        node: dict,
    ):
        knowledge_configuration_dict = node.get("data", {})
        knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)

        if indexing_technique == "high_quality":
            knowledge_configuration.embedding_model = dataset.embedding_model

304 api/services/retention/conversation/message_export_service.py Normal file
@ -0,0 +1,304 @@
"""
Export app messages to JSONL.GZ format.

Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).

Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""

import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast

import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session

from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback

logger = logging.getLogger(__name__)

MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")


class AppMessageExportFeedback(BaseModel):
    id: str
    app_id: str
    conversation_id: str
    message_id: str
    rating: str
    content: str | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: str
    updated_at: str

    model_config = ConfigDict(extra="forbid")


class AppMessageExportRecord(BaseModel):
    conversation_id: str
    message_id: str
    query: str
    answer: str
    inputs: dict[str, Any]
    retriever_resources: list[Any] = Field(default_factory=list)
    feedback: list[AppMessageExportFeedback] = Field(default_factory=list)

    model_config = ConfigDict(extra="forbid")


class AppMessageExportStats(BaseModel):
    batches: int = 0
    total_messages: int = 0
    messages_with_feedback: int = 0
    total_feedbacks: int = 0

    model_config = ConfigDict(extra="forbid")


class AppMessageExportService:
    @staticmethod
    def validate_export_filename(filename: str) -> str:
        normalized = filename.strip()
        if not normalized:
            raise ValueError("--filename must not be empty.")

        normalized_lower = normalized.lower()
        if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
            raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")

        if normalized.startswith("/"):
            raise ValueError("--filename must be a relative path; absolute paths are not allowed.")

        if "\\" in normalized:
            raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")

        if "//" in normalized:
            raise ValueError("--filename must not contain empty path segments ('//').")

        if len(normalized) > MAX_FILENAME_BASE_LENGTH:
            raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")

        for ch in normalized:
            if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
                raise ValueError("--filename must not contain control characters or NUL.")

        parts = PurePosixPath(normalized).parts
        if not parts:
            raise ValueError("--filename must include a file name.")

        if any(part in (".", "..") for part in parts):
            raise ValueError("--filename must not contain '.' or '..' path segments.")

        return normalized

    @property
    def output_gz_name(self) -> str:
        return f"{self._filename_base}.jsonl.gz"

    @property
    def output_jsonl_name(self) -> str:
        return f"{self._filename_base}.jsonl"

    def __init__(
        self,
        app_id: str,
        end_before: datetime.datetime,
        filename: str,
        *,
        start_from: datetime.datetime | None = None,
        batch_size: int = 1000,
        use_cloud_storage: bool = False,
        dry_run: bool = False,
    ) -> None:
        if start_from and start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")

        self._app_id = app_id
        self._end_before = end_before
        self._start_from = start_from
        self._filename_base = self.validate_export_filename(filename)
        self._batch_size = batch_size
        self._use_cloud_storage = use_cloud_storage
        self._dry_run = dry_run

    def run(self) -> AppMessageExportStats:
        stats = AppMessageExportStats()

        logger.info(
            "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
            self._app_id,
            self._start_from,
            self._end_before,
            self._dry_run,
            self._use_cloud_storage,
            self.output_gz_name,
        )

        if self._dry_run:
            for _ in self._iter_records_with_stats(stats):
                pass
            self._finalize_stats(stats)
            return stats

        if self._use_cloud_storage:
            self._export_to_cloud(stats)
        else:
            self._export_to_local(stats)

        self._finalize_stats(stats)
        return stats

    def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
        for batch in self._iter_record_batches():
            yield from batch

    @staticmethod
    def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
        with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
            for record in records:
                gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")

    def _export_to_local(self, stats: AppMessageExportStats) -> None:
        output_path = Path.cwd() / self.output_gz_name
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as output_file:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)

    def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
        with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
            tmp.seek(0)
            data = tmp.read()

        storage.save(self.output_gz_name, data)
        logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)

    def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
        for record in self.iter_records():
            self._update_stats(stats, record)
            yield record

    @staticmethod
    def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
        stats.total_messages += 1
        if record.feedback:
            stats.messages_with_feedback += 1
            stats.total_feedbacks += len(record.feedback)

    def _finalize_stats(self, stats: AppMessageExportStats) -> None:
        if stats.total_messages == 0:
            stats.batches = 0
            return
        stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size

    def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
        cursor: tuple[datetime.datetime, str] | None = None
        while True:
            rows, cursor = self._fetch_batch(cursor)
            if not rows:
                break

            message_ids = [str(row.id) for row in rows]
            feedbacks_map = self._fetch_feedbacks(message_ids)
            yield [self._build_record(row, feedbacks_map) for row in rows]

    def _fetch_batch(
        self, cursor: tuple[datetime.datetime, str] | None
    ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(
                    Message.id,
                    Message.conversation_id,
                    Message.query,
                    Message.answer,
                    Message._inputs,  # pyright: ignore[reportPrivateUsage]
                    Message.message_metadata,
                    Message.created_at,
                )
                .where(
                    Message.app_id == self._app_id,
                    Message.created_at < self._end_before,
                )
                .order_by(Message.created_at, Message.id)
                .limit(self._batch_size)
            )

            if self._start_from:
                stmt = stmt.where(Message.created_at >= self._start_from)

            if cursor:
                stmt = stmt.where(
                    tuple_(Message.created_at, Message.id)
                    > tuple_(
                        sa.literal(cursor[0], type_=sa.DateTime()),
                        sa.literal(cursor[1], type_=Message.id.type),
                    )
                )

            rows = list(session.execute(stmt).all())

        if not rows:
            return [], cursor

        last = rows[-1]
        return rows, (last.created_at, last.id)

    def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
        if not message_ids:
            return {}

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(MessageFeedback)
                .where(
                    MessageFeedback.message_id.in_(message_ids),
                    MessageFeedback.from_source == "user",
                )
                .order_by(MessageFeedback.message_id, MessageFeedback.created_at)
            )
            feedbacks = list(session.scalars(stmt).all())

        result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
        for feedback in feedbacks:
            result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
        return result

    @staticmethod
    def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
        retriever_resources: list[Any] = []
        if row.message_metadata:
            try:
                metadata = json.loads(row.message_metadata)
                value = metadata.get("retriever_resources", [])
                if isinstance(value, list):
                    retriever_resources = value
            except (json.JSONDecodeError, TypeError):
                pass

        message_id = str(row.id)
        return AppMessageExportRecord(
            conversation_id=str(row.conversation_id),
            message_id=message_id,
            query=row.query,
            answer=row.answer,
            inputs=row._inputs if isinstance(row._inputs, dict) else {},
            retriever_resources=retriever_resources,
            feedback=feedbacks_map.get(message_id, []),
        )
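
A usage sketch for the exporter added above; the ids and paths are hypothetical, and dry_run=True walks the (created_at, id) keyset cursor described in the module docstring to collect stats without writing a file:

    import datetime

    from services.retention.conversation.message_export_service import AppMessageExportService

    service = AppMessageExportService(
        app_id="00000000-0000-0000-0000-000000000000",  # hypothetical app id
        end_before=datetime.datetime(2026, 1, 1),
        filename="exports/app-messages",  # base name; .jsonl.gz is appended
        batch_size=500,
        dry_run=True,  # count records and batches only
    )
    stats = service.run()
    print(stats.total_messages, stats.batches)
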
@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session

from dify_graph.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
@ -132,7 +132,14 @@ class WorkflowAppService:
            ),
        )
        if created_by_account:
            account = session.scalar(select(Account).where(Account.email == created_by_account))
            account = session.scalar(
                select(Account)
                .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
                .where(
                    Account.email == created_by_account,
                    TenantAccountJoin.tenant_id == app_model.tenant_id,
                )
            )
            if not account:
                raise ValueError(f"Account not found: {created_by_account}")

@ -11,13 +11,13 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import ErrorStrategy, UserFrom, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.errors import WorkflowNodeRunFailedError
from dify_graph.file import File
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
@ -1063,13 +1063,15 @@ class WorkflowService:
        variable_pool: VariablePool,
    ) -> HumanInputNode:
        graph_init_params = GraphInitParams(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            graph_config=workflow.graph_dict,
            user_id=account.id,
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.DEBUGGER,
            run_context=build_dify_run_context(
                tenant_id=workflow.tenant_id,
                app_id=workflow.app_id,
                user_id=account.id,
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
            ),
            call_depth=0,
        )
        graph_runtime_state = GraphRuntimeState(

@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
@shared_task(queue="dataset_summary")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
    """
    Async generate summary index for document segments.

@ -6,7 +6,6 @@ import typing
import click
from celery import shared_task

from core.helper.marketplace import record_install_plugin_event
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task(
                # execute upgrade
                new_unique_identifier = manifest.latest_package_identifier

                record_install_plugin_event(new_unique_identifier)
                click.echo(
                    click.style(
                        f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",

@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
@shared_task(queue="dataset_summary")
def regenerate_summary_index_task(
    dataset_id: str,
    regenerate_reason: str = "summary_model_changed",

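
Both summary-index tasks above now target the dataset_summary queue, matching the DEFAULT_QUEUES list updated in this commit; a worker started without dataset_summary in its queue list would leave these tasks unconsumed. A minimal dispatch sketch (the dataset id is hypothetical; the explicit queue override is redundant once the decorator pins it):

    from tasks.regenerate_summary_index_task import regenerate_summary_index_task

    regenerate_summary_index_task.apply_async(
        args=["hypothetical-dataset-id"],
        queue="dataset_summary",  # matches @shared_task(queue="dataset_summary")
    )
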
@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement
for App descriptions across all creation and editing endpoints.
"""

import os
import sys

import pytest

# Add the API root to Python path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))


class TestAppDescriptionValidationUnit:
    """Unit tests for description validation function"""

@ -4,10 +4,9 @@ import uuid
import pytest

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.code.code_node import CodeNode
@ -15,6 +14,7 @@ from dify_graph.nodes.code.limits import CodeNodeLimits
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
from tests.workflow_test_utils import build_test_graph_init_params

CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH

@ -31,11 +31,11 @@ def init_code_node(code_config: dict):
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

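The recurring refactor in these test diffs replaces hand-built GraphInitParams with a shared build_test_graph_init_params helper. A plausible shape for such a helper, inferred from the call sites; the real tests.workflow_test_utils.build_test_graph_init_params may differ:

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from dify_graph.entities import GraphInitParams


def build_test_graph_init_params(
    *,
    workflow_id: str,
    graph_config: dict,
    tenant_id: str = "1",
    app_id: str = "1",
    user_id: str = "1",
    user_from: UserFrom | str = UserFrom.ACCOUNT,
    invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER,
    call_depth: int = 0,
) -> GraphInitParams:
    # Centralize the boilerplate so each node test states only what it
    # actually cares about: the graph config and the ids under test.
    return GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from=user_from,
        invoke_from=invoke_from,
        call_depth=call_depth,
    )
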
@ -5,18 +5,18 @@ from urllib.parse import urlencode
import pytest

from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.file.file_manager import file_manager
from dify_graph.graph import Graph
from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
from tests.workflow_test_utils import build_test_graph_init_params

HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -41,11 +41,11 @@ def init_http_node(config: dict):
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
@ -685,11 +685,11 @@ def test_nested_object_variable_selector(setup_http_mock):
],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

@ -4,17 +4,17 @@ import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from tests.workflow_test_utils import build_test_graph_init_params

"""FOR MOCK FIXTURES, DO NOT REMOVE"""

@ -37,11 +37,11 @@ def init_llm_node(config: dict) -> LLMNode:
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"

init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
init_params = build_test_graph_init_params(
workflow_id=workflow_id,
graph_config=graph_config,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

@ -3,10 +3,9 @@ import time
import uuid
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.model_manager import ModelInstance
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@ -14,6 +13,7 @@ from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
from tests.workflow_test_utils import build_test_graph_init_params

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -43,11 +43,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

@ -1,15 +1,15 @@
import time
import uuid

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError
from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params


class _SimpleJinja2Renderer:
@ -53,11 +53,11 @@ def test_execute_template_transform():
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

@ -2,16 +2,16 @@ import time
import uuid
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom, WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.graph import Graph
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.tool.tool_node import ToolNode
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params


def init_tool_node(config: dict):
@ -26,11 +26,11 @@ def init_tool_node(config: dict):
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}

init_params = GraphInitParams(
tenant_id="1",
app_id="1",
init_params = build_test_graph_init_params(
workflow_id="1",
graph_config=graph_config,
tenant_id="1",
app_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,

@ -10,8 +10,11 @@ more reliable and realistic test scenarios.
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Protocol, TypeVar

import psycopg2
import pytest
from flask import Flask
from flask.testing import FlaskClient
@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)


class _CloserProtocol(Protocol):
"""_Closer is any type that implements the close() method."""

def close(self):
"""Close the current object, releasing any external resource (file, transaction, connection, etc.)
associated with it.
"""
pass


_Closer = TypeVar("_Closer", bound=_CloserProtocol)


@contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
yield closer
closer.close()


class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
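One caveat about _auto_close as written: close() sits after a bare yield, so it is skipped when the with-body raises. The standard library's contextlib.closing wraps the yield in try/finally and closes on both paths; a usage sketch with illustrative connection parameters:

import psycopg2
from contextlib import closing

with closing(psycopg2.connect(host="localhost", port=5432, user="postgres", password="postgres", dbname="postgres")) as conn:
    conn.autocommit = True
    # conn.close() runs on exit here even if the body raises.
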
@ -97,45 +119,28 @@ class DifyTestContainers:
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")

# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2

conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
with _auto_close(conn):
with conn.cursor() as cursor:
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)

# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
try:
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("CREATE DATABASE dify_plugin;")
cursor.close()
conn.close()
# NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement
# inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block.
with _auto_close(conn.cursor()) as cursor:
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
cursor.execute("CREATE DATABASE dify_plugin;")
logger.info("Plugin database created successfully")
except Exception as e:
logger.warning("Failed to create plugin database: %s", e)

# Set up storage environment variables
os.environ.setdefault("STORAGE_TYPE", "opendal")
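The NOTE in the hunk above reflects a PostgreSQL rule: CREATE DATABASE cannot run inside a transaction block, and psycopg2 opens an implicit transaction on the first statement unless the connection is in autocommit mode. Stripped of the container plumbing, the working shape is roughly as follows (connection parameters are illustrative):

import psycopg2

conn = psycopg2.connect(host="localhost", port=5432, user="postgres", password="postgres", dbname="postgres")
conn.autocommit = True  # CREATE DATABASE refuses to run inside a transaction block
try:
    cur = conn.cursor()
    cur.execute("CREATE DATABASE dify_plugin;")
    cur.close()
finally:
    conn.close()
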
@ -258,23 +263,16 @@ class DifyTestContainers:
containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon]
for container in containers:
if container:
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)

# Stop and remove the network
if self.network:
try:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
except Exception as e:
logger.warning("Failed to remove Docker network: %s", e)
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")

self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")

@ -12,7 +12,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from dify_graph.entities import GraphInitParams
from dify_graph.enums import WorkflowType
from dify_graph.graph import Graph
from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel
@ -33,6 +32,7 @@ from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.model import App, AppMode, IconType
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
from tests.workflow_test_utils import build_test_graph_init_params


def _mock_form_repository_without_submission() -> HumanInputFormRepository:
@ -87,11 +87,11 @@ def _build_graph(
form_repository: HumanInputFormRepository,
) -> Graph:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
params = build_test_graph_init_params(
workflow_id=workflow_id,
graph_config=graph_config,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from="account",
invoke_from="debugger",

@ -0,0 +1,497 @@
"""
Container-backed integration tests for dataset permission services on the real SQL path.

This module exercises persisted DatasetPermission rows and dataset permission
checks with testcontainers-backed infrastructure instead of database-chain mocks.
"""

from uuid import uuid4

import pytest

from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
Dataset,
DatasetPermission,
DatasetPermissionEnum,
)
from services.dataset_service import DatasetPermissionService, DatasetService
from services.errors.account import NoPermissionError


class DatasetPermissionTestDataFactory:
"""Create persisted entities and request payloads for dataset permission integration tests."""

@staticmethod
def create_account_with_tenant(
role: TenantAccountRole = TenantAccountRole.NORMAL,
tenant: Tenant | None = None,
) -> tuple[Account, Tenant]:
"""Create a real account and tenant with specified role."""
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add_all([account, tenant])
else:
db.session.add(account)

db.session.flush()

join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db.session.add(join)
db.session.commit()

account.current_tenant = tenant
return account, tenant

@staticmethod
def create_dataset(
tenant_id: str,
created_by: str,
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
name: str = "Test Dataset",
) -> Dataset:
"""Create a real dataset with specified attributes."""
dataset = Dataset(
tenant_id=tenant_id,
name=name,
description="desc",
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=created_by,
permission=permission,
provider="vendor",
retrieval_model={"top_k": 2},
)
db.session.add(dataset)
db.session.commit()
return dataset

@staticmethod
def create_dataset_permission(
dataset_id: str,
account_id: str,
tenant_id: str,
has_permission: bool = True,
) -> DatasetPermission:
"""Create a real DatasetPermission instance."""
permission = DatasetPermission(
dataset_id=dataset_id,
account_id=account_id,
tenant_id=tenant_id,
has_permission=has_permission,
)
db.session.add(permission)
db.session.commit()
return permission

@staticmethod
def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]:
"""Build the request payload shape used by partial-member list updates."""
return [{"user_id": user_id} for user_id in user_ids]


class TestDatasetPermissionServiceGetPartialMemberList:
"""Verify partial-member list reads against persisted DatasetPermission rows."""

def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
"""
Test retrieving partial member list with multiple members.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
user_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
user_3, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)

expected_account_ids = [user_1.id, user_2.id, user_3.id]
for account_id in expected_account_ids:
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id)

# Act
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

# Assert
assert set(result) == set(expected_account_ids)
assert len(result) == 3

def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
"""
Test retrieving partial member list with single member.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)

expected_account_ids = [user.id]
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)

# Act
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

# Assert
assert set(result) == set(expected_account_ids)
assert len(result) == 1

def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
"""
Test retrieving partial member list when no members exist.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)

# Act
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)

# Assert
assert result == []
assert len(result) == 0


class TestDatasetPermissionServiceUpdatePartialMemberList:
"""Verify partial-member list updates against persisted DatasetPermission rows."""

def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
"""
Test adding new partial members to a dataset.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])

# Act
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {member_1.id, member_2.id}

def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
"""
Test replacing existing partial members with new ones.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)

old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id])
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users)

new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id])

# Act
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert set(result) == {new_member_1.id, new_member_2.id}

def test_update_partial_member_list_empty_list(self, db_session_with_containers):
"""
Test updating with empty member list (clearing all members).
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)

# Act
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, [])

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []

def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
"""
Test error handling and rollback on database error.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
DatasetPermissionService.update_partial_member_list(
tenant.id,
dataset.id,
DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]),
)
user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id])
rollback_called = {"count": 0}
original_rollback = db.session.rollback

# Act / Assert
with pytest.MonkeyPatch.context() as mp:

def _raise_commit():
raise Exception("Database connection error")

def _rollback_and_mark():
rollback_called["count"] += 1
original_rollback()

mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
with pytest.raises(Exception, match="Database connection error"):
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert rollback_called["count"] == 1
assert result == [existing_member.id]
assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1


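A note on the rollback pattern above: patching db.session.rollback with a wrapper that still calls the original keeps the session usable for the assertions that follow, while the counter pins down that the service invoked rollback exactly once before re-raising. Replacing rollback outright would risk leaving the session in a failed transaction state and poisoning the verification queries.
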
class TestDatasetPermissionServiceClearPartialMemberList:
"""Verify partial-member clearing against persisted DatasetPermission rows."""

def test_clear_partial_member_list_success(self, db_session_with_containers):
"""
Test successful clearing of partial member list.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)

# Act
DatasetPermissionService.clear_partial_member_list(dataset.id)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []

def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
"""
Test clearing partial member list when no members exist.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)

# Act
DatasetPermissionService.clear_partial_member_list(dataset.id)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert result == []

def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
"""
Test error handling and rollback on database error.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
rollback_called = {"count": 0}
original_rollback = db.session.rollback

# Act / Assert
with pytest.MonkeyPatch.context() as mp:

def _raise_commit():
raise Exception("Database connection error")

def _rollback_and_mark():
rollback_called["count"] += 1
original_rollback()

mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
with pytest.raises(Exception, match="Database connection error"):
DatasetPermissionService.clear_partial_member_list(dataset.id)

# Assert
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert rollback_called["count"] == 1
assert set(result) == {member_1.id, member_2.id}
assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2


class TestDatasetServiceCheckDatasetPermission:
"""Verify dataset access checks against persisted partial-member permissions."""

def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
"""
Test that user with explicit permission can access partial_members dataset.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)

dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id,
owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)

# Act (should not raise)
DatasetService.check_dataset_permission(dataset, user)

# Assert
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert user.id in permissions

def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
"""
Test error when user without permission tries to access partial_members dataset.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)

dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id,
owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)

# Act & Assert
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_permission(dataset, user)


class TestDatasetServiceCheckDatasetOperatorPermission:
"""Verify operator permission checks against persisted partial-member permissions."""

def test_check_dataset_operator_permission_partial_members_with_permission_success(
self, db_session_with_containers
):
"""
Test that user with explicit permission can access partial_members dataset.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)

dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id,
owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)

# Act (should not raise)
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

# Assert
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
assert user.id in permissions

def test_check_dataset_operator_permission_partial_members_without_permission_error(
self, db_session_with_containers
):
"""
Test error when user without permission tries to access partial_members dataset.
"""
# Arrange
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.NORMAL,
tenant=tenant,
)

dataset = DatasetPermissionTestDataFactory.create_dataset(
tenant.id,
owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)

# Act & Assert
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
@ -0,0 +1,244 @@
"""Container-backed integration tests for DatasetService.delete_dataset real SQL paths."""

from unittest.mock import patch
from uuid import uuid4

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from services.dataset_service import DatasetService


class DatasetDeleteIntegrationDataFactory:
"""Create persisted entities used by delete_dataset integration tests."""

@staticmethod
def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]:
"""Persist an owner account, tenant, and tenant join for dataset deletion tests."""
account = Account(
email=f"owner-{uuid4()}@example.com",
name="Owner",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()

tenant = Tenant(
name=f"tenant-{uuid4()}",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()

join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()

account.current_tenant = tenant
return account, tenant

@staticmethod
def create_dataset(
db_session_with_containers,
tenant_id: str,
created_by: str,
*,
indexing_technique: str | None,
chunk_structure: str | None,
index_struct: str | None = '{"type": "paragraph"}',
collection_binding_id: str | None = None,
pipeline_id: str | None = None,
) -> Dataset:
"""Persist a dataset with delete_dataset-relevant fields configured."""
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
indexing_technique=indexing_technique,
index_struct=index_struct,
created_by=created_by,
collection_binding_id=collection_binding_id,
pipeline_id=pipeline_id,
chunk_structure=chunk_structure,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset

@staticmethod
def create_document(
db_session_with_containers,
*,
tenant_id: str,
dataset_id: str,
created_by: str,
doc_form: str = "text_model",
) -> Document:
"""Persist a document so dataset.doc_form resolves through the real document path."""
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=1,
data_source_type="upload_file",
batch=f"batch-{uuid4()}",
name="Document",
created_from="upload_file",
created_by=created_by,
doc_form=doc_form,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document


class TestDatasetServiceDeleteDataset:
"""Integration coverage for DatasetService.delete_dataset using testcontainers."""

def test_delete_dataset_with_documents_success(self, db_session_with_containers):
"""Delete a dataset with documents and dispatch cleanup through the real signal handler."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetDeleteIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
pipeline_id=str(uuid4()),
)
DatasetDeleteIntegrationDataFactory.create_document(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=owner.id,
doc_form="text_model",
)

# Act
with patch(
"events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay",
autospec=True,
) as clean_dataset_delay:
result = DatasetService.delete_dataset(dataset.id, owner)

# Assert
db_session_with_containers.expire_all()
assert result is True
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_called_once_with(
dataset.id,
dataset.tenant_id,
dataset.indexing_technique,
dataset.index_struct,
dataset.collection_binding_id,
dataset.doc_form,
dataset.pipeline_id,
)

def test_delete_empty_dataset_success(self, db_session_with_containers):
"""Delete an empty dataset without scheduling cleanup when both gating fields are absent."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetDeleteIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique=None,
chunk_structure=None,
index_struct=None,
collection_binding_id=None,
pipeline_id=None,
)

# Act
with patch(
"events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay",
autospec=True,
) as clean_dataset_delay:
result = DatasetService.delete_dataset(dataset.id, owner)

# Assert
db_session_with_containers.expire_all()
assert result is True
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()

def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
"""Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetDeleteIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique=None,
chunk_structure="text_model",
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
pipeline_id=str(uuid4()),
)

# Act
with patch(
"events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay",
autospec=True,
) as clean_dataset_delay:
result = DatasetService.delete_dataset(dataset.id, owner)

# Assert
db_session_with_containers.expire_all()
assert result is True
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()

def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers):
"""Delete a dataset without cleanup when indexing exists but doc_form resolves to None."""
# Arrange
owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetDeleteIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
indexing_technique="high_quality",
chunk_structure=None,
index_struct='{"type": "paragraph"}',
collection_binding_id=str(uuid4()),
pipeline_id=str(uuid4()),
)

# Act
with patch(
"events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay",
autospec=True,
) as clean_dataset_delay:
result = DatasetService.delete_dataset(dataset.id, owner)

# Assert
db_session_with_containers.expire_all()
assert result is True
assert db_session_with_containers.get(Dataset, dataset.id) is None
clean_dataset_delay.assert_not_called()

def test_delete_dataset_not_found(self, db_session_with_containers):
"""Return False without scheduling cleanup when the target dataset does not exist."""
# Arrange
owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
missing_dataset_id = str(uuid4())

# Act
with patch(
"events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay",
autospec=True,
) as clean_dataset_delay:
result = DatasetService.delete_dataset(missing_dataset_id, owner)

# Assert
assert result is False
clean_dataset_delay.assert_not_called()
@ -0,0 +1,233 @@
import datetime
import json
import uuid
from decimal import Decimal

import pytest
from sqlalchemy.orm import Session

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats


class TestAppMessageExportServiceIntegration:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()

@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()

tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()

join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()

app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()

conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation

@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message

def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)

first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}

base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)

user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()

service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)

assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id

assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs

assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []

assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []

assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2
@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")

# Add the API directory to Python path to ensure proper imports
import sys

sys.path.insert(0, PROJECT_DIR)

from core.db.session_factory import configure_session_factory, session_factory
from extensions import ext_redis


@ -9,8 +9,16 @@ import pytest

from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
from core.app.entities.queue_entities import (
QueuePingEvent,
QueueTextChunkEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import StreamEvent
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None:

assert message.answer == "beforeafter"
assert message.status == MessageStatus.NORMAL


def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
pipeline = _build_pipeline()
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_id = "workflow-1"
pipeline._ensure_workflow_initialized = mock.Mock()
runtime_state = SimpleNamespace()
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
)
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
event=StreamEvent.WORKFLOW_FINISHED,
data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED),
)

event = QueueWorkflowSucceededEvent(outputs={})
responses = list(pipeline._handle_workflow_succeeded_event(event))

assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]


def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None:
pipeline = _build_pipeline()
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
pipeline._workflow_id = "workflow-1"
pipeline._ensure_workflow_initialized = mock.Mock()
runtime_state = SimpleNamespace()
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
)
pipeline._workflow_response_converter = mock.Mock()
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
event=StreamEvent.WORKFLOW_FINISHED,
data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED),
)

event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
responses = list(pipeline._handle_workflow_partial_success_event(event))

assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]


def test_process_stream_response_breaks_after_workflow_succeeded() -> None:
pipeline = _build_pipeline()
succeeded_event = QueueWorkflowSucceededEvent(outputs={})
ping_event = QueuePingEvent()
queue_messages = [
SimpleNamespace(event=succeeded_event),
SimpleNamespace(event=ping_event),
]

pipeline._conversation_name_generate_thread = None
pipeline._base_task_pipeline = mock.Mock()
pipeline._base_task_pipeline.queue_manager = mock.Mock()
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
pipeline._handle_workflow_succeeded_event = mock.Mock(
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
)

responses = list(pipeline._process_stream_response())

assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None)
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()


def test_process_stream_response_breaks_after_workflow_partial_success() -> None:
pipeline = _build_pipeline()
partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
ping_event = QueuePingEvent()
queue_messages = [
SimpleNamespace(event=partial_event),
SimpleNamespace(event=ping_event),
]

pipeline._conversation_name_generate_thread = None
pipeline._base_task_pipeline = mock.Mock()
pipeline._base_task_pipeline.queue_manager = mock.Mock()
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
pipeline._handle_workflow_partial_success_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
|
||||
)
|
||||
|
||||
responses = list(pipeline._process_stream_response())
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
|
||||
pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None)
|
||||
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
|
||||
|
||||
@ -1,19 +1,13 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import dify_graph.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.pause_reason import SchedulingPause
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@ -34,6 +28,7 @@ from dify_graph.nodes.end.entities import EndNodeData
|
||||
from dify_graph.nodes.start.entities import StartNodeData
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
ops_stub = ModuleType("core.ops.ops_trace_manager")
|
||||
@ -142,11 +137,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
|
||||
graph_config = _build_graph_config(pause_on=pause_on)
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
params = build_test_graph_init_params(
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
|
||||
@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
|
||||
|
||||
This test suite ensures that the files array is correctly populated in the message_end
|
||||
SSE event, which is critical for vision/image chat responses to render correctly.
|
||||
|
||||
Test Coverage:
|
||||
- Files array populated when MessageFile records exist
|
||||
- Files array is None when no MessageFile records exist
|
||||
- Correct signed URL generation for LOCAL_FILE transfer method
|
||||
- Correct URL handling for REMOTE_URL transfer method
|
||||
- Correct URL handling for TOOL_FILE transfer method
|
||||
- Proper file metadata formatting (filename, mime_type, size, extension)
|
||||
"""

import uuid
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy.orm import Session

from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile


class TestMessageEndStreamResponseFiles:
    """Test suite for files array population in message_end SSE event."""

    @pytest.fixture
    def mock_pipeline(self):
        """Create a mock EasyUIBasedGenerateTaskPipeline instance."""
        pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
        pipeline._message_id = str(uuid.uuid4())
        pipeline._task_state = Mock()
        pipeline._task_state.metadata = Mock()
        pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
        pipeline._task_state.llm_result = Mock()
        pipeline._task_state.llm_result.usage = Mock()
        pipeline._application_generate_entity = Mock()
        pipeline._application_generate_entity.task_id = str(uuid.uuid4())
        return pipeline

    @pytest.fixture
    def mock_message_file_local(self):
        """Create a mock MessageFile with LOCAL_FILE transfer method."""
        message_file = Mock(spec=MessageFile)
        message_file.id = str(uuid.uuid4())
        message_file.message_id = str(uuid.uuid4())
        message_file.transfer_method = FileTransferMethod.LOCAL_FILE
        message_file.upload_file_id = str(uuid.uuid4())
        message_file.url = None
        message_file.type = "image"
        return message_file

    @pytest.fixture
    def mock_message_file_remote(self):
        """Create a mock MessageFile with REMOTE_URL transfer method."""
        message_file = Mock(spec=MessageFile)
        message_file.id = str(uuid.uuid4())
        message_file.message_id = str(uuid.uuid4())
        message_file.transfer_method = FileTransferMethod.REMOTE_URL
        message_file.upload_file_id = None
        message_file.url = "https://example.com/image.jpg"
        message_file.type = "image"
        return message_file

    @pytest.fixture
    def mock_message_file_tool(self):
        """Create a mock MessageFile with TOOL_FILE transfer method."""
        message_file = Mock(spec=MessageFile)
        message_file.id = str(uuid.uuid4())
        message_file.message_id = str(uuid.uuid4())
        message_file.transfer_method = FileTransferMethod.TOOL_FILE
        message_file.upload_file_id = None
        message_file.url = "tool_file_123.png"
        message_file.type = "image"
        return message_file

    @pytest.fixture
    def mock_upload_file(self, mock_message_file_local):
        """Create a mock UploadFile."""
        upload_file = Mock(spec=UploadFile)
        upload_file.id = mock_message_file_local.upload_file_id
        upload_file.name = "test_image.png"
        upload_file.mime_type = "image/png"
        upload_file.size = 1024
        upload_file.extension = "png"
        return upload_file

    def test_message_end_with_no_files(self, mock_pipeline):
        """Test that files array is None when no MessageFile records exist."""
        # Arrange
        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine

            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_session.scalars.return_value.all.return_value = []

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is None
            assert result.id == mock_pipeline._message_id
            assert result.metadata == {"test": "metadata"}

    def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
        """Test that files array is populated correctly for LOCAL_FILE transfer method."""
        # Arrange
        mock_message_file_local.message_id = mock_pipeline._message_id

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine

            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            # First query: MessageFile
            mock_message_files_result = Mock()
            mock_message_files_result.all.return_value = [mock_message_file_local]

            # Second query: UploadFile (batch query to avoid N+1)
            mock_upload_files_result = Mock()
            mock_upload_files_result.all.return_value = [mock_upload_file]

            # Setup scalars to return different results for different queries
            call_count = [0]  # Use list to allow modification in nested function

            def scalars_side_effect(query):
                call_count[0] += 1
                # First call is for MessageFile, second call is for UploadFile
                if call_count[0] == 1:
                    return mock_message_files_result
                else:
                    return mock_upload_files_result

            mock_session.scalars.side_effect = scalars_side_effect
            mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 1

            file_dict = result.files[0]
            assert file_dict["related_id"] == mock_message_file_local.id
            assert file_dict["filename"] == "test_image.png"
            assert file_dict["mime_type"] == "image/png"
            assert file_dict["size"] == 1024
            assert file_dict["extension"] == ".png"
            assert file_dict["type"] == "image"
            assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
            assert "https://example.com/signed-url" in file_dict["url"]
            assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
            assert file_dict["remote_url"] == ""

            # Verify database queries
            # Should be called twice: once for MessageFile, once for UploadFile
            assert mock_session.scalars.call_count == 2
            mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
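
    # Editorial sketch (not part of this change): the batched UploadFile lookup that
    # the two-call scalars mock above emulates. Names here are hypothetical; the point
    # is a single IN-query for all upload_file_ids instead of one query per file (N+1).
    #
    #     upload_file_ids = [mf.upload_file_id for mf in message_files if mf.upload_file_id]
    #     upload_files = session.scalars(
    #         select(UploadFile).where(UploadFile.id.in_(upload_file_ids))
    #     ).all()
    #     upload_file_map = {uf.id: uf for uf in upload_files}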

    def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
        """Test that files array is populated correctly for REMOTE_URL transfer method."""
        # Arrange
        mock_message_file_remote.message_id = mock_pipeline._message_id

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine
            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            mock_scalars_result = Mock()
            mock_scalars_result.all.return_value = [mock_message_file_remote]
            mock_session.scalars.return_value = mock_scalars_result

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 1

            file_dict = result.files[0]
            assert file_dict["related_id"] == mock_message_file_remote.id
            assert file_dict["filename"] == "image.jpg"
            assert file_dict["url"] == "https://example.com/image.jpg"
            assert file_dict["extension"] == ".jpg"
            assert file_dict["type"] == "image"
            assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
            assert file_dict["remote_url"] == "https://example.com/image.jpg"
            assert file_dict["upload_file_id"] == mock_message_file_remote.id

            # Verify only one query for message_files is made
            mock_session.scalars.assert_called_once()

    def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
        """Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
        # Arrange
        mock_message_file_tool.message_id = mock_pipeline._message_id
        mock_message_file_tool.url = "https://example.com/tool_file.png"

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine
            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            mock_scalars_result = Mock()
            mock_scalars_result.all.return_value = [mock_message_file_tool]
            mock_session.scalars.return_value = mock_scalars_result

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 1

            file_dict = result.files[0]
            assert file_dict["url"] == "https://example.com/tool_file.png"
            assert file_dict["filename"] == "tool_file.png"
            assert file_dict["extension"] == ".png"
            assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value

    def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
        """Test that files array is populated correctly for TOOL_FILE with local path."""
        # Arrange
        mock_message_file_tool.message_id = mock_pipeline._message_id
        mock_message_file_tool.url = "tool_file_123.png"

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine

            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            mock_scalars_result = Mock()
            mock_scalars_result.all.return_value = [mock_message_file_tool]
            mock_session.scalars.return_value = mock_scalars_result

            mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 1

            file_dict = result.files[0]
            assert "https://example.com/signed-tool-file.png" in file_dict["url"]
            assert file_dict["filename"] == "tool_file_123.png"
            assert file_dict["extension"] == ".png"
            assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value

            # Verify tool file signing was called
            mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")

    def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
        """Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
        mock_message_file_tool.message_id = mock_pipeline._message_id
        mock_message_file_tool.url = "tool_file_abc.verylongextension"

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine
            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session
            mock_scalars_result = Mock()
            mock_scalars_result.all.return_value = [mock_message_file_tool]
            mock_session.scalars.return_value = mock_scalars_result
            mock_sign_tool.return_value = "https://example.com/signed.bin"

            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            assert result.files is not None
            file_dict = result.files[0]
            assert file_dict["extension"] == ".bin"
            mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
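
    # Editorial sketch (not part of this change): the id/extension split implied by the
    # two sign_tool_file assertions above. Helper and constant names are assumptions.
    #
    #     stem, _, ext = message_file.url.rpartition(".")
    #     extension = f".{ext}" if len(ext) <= MAX_TOOL_FILE_EXTENSION_LENGTH else ".bin"
    #     url = sign_tool_file(tool_file_id=stem, extension=extension)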

    def test_message_end_with_multiple_files(
        self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
    ):
        """Test that files array contains all MessageFile records when multiple exist."""
        # Arrange
        mock_message_file_local.message_id = mock_pipeline._message_id
        mock_message_file_remote.message_id = mock_pipeline._message_id

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine

            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            # First query: MessageFile
            mock_message_files_result = Mock()
            mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]

            # Second query: UploadFile (batch query to avoid N+1)
            mock_upload_files_result = Mock()
            mock_upload_files_result.all.return_value = [mock_upload_file]

            # Setup scalars to return different results for different queries
            call_count = [0]  # Use list to allow modification in nested function

            def scalars_side_effect(query):
                call_count[0] += 1
                # First call is for MessageFile, second call is for UploadFile
                if call_count[0] == 1:
                    return mock_message_files_result
                else:
                    return mock_upload_files_result

            mock_session.scalars.side_effect = scalars_side_effect
            mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 2

            # Verify both files are present
            file_ids = [f["related_id"] for f in result.files]
            assert mock_message_file_local.id in file_ids
            assert mock_message_file_remote.id in file_ids

    def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
        """Test fallback when UploadFile is not found for LOCAL_FILE."""
        # Arrange
        mock_message_file_local.message_id = mock_pipeline._message_id

        with (
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
            patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
        ):
            mock_engine = MagicMock()
            mock_db.engine = mock_engine

            mock_session = MagicMock(spec=Session)
            mock_session_class.return_value.__enter__.return_value = mock_session

            # Mock database queries
            # First query: MessageFile
            mock_message_files_result = Mock()
            mock_message_files_result.all.return_value = [mock_message_file_local]

            # Second query: UploadFile (batch query) - returns empty list (not found)
            mock_upload_files_result = Mock()
            mock_upload_files_result.all.return_value = []  # UploadFile not found

            # Setup scalars to return different results for different queries
            call_count = [0]  # Use list to allow modification in nested function

            def scalars_side_effect(query):
                call_count[0] += 1
                # First call is for MessageFile, second call is for UploadFile
                if call_count[0] == 1:
                    return mock_message_files_result
                else:
                    return mock_upload_files_result

            mock_session.scalars.side_effect = scalars_side_effect
            mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"

            # Act
            result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

            # Assert
            assert isinstance(result, MessageEndStreamResponse)
            assert result.files is not None
            assert len(result.files) == 1

            file_dict = result.files[0]
            assert "https://example.com/fallback-url" in file_dict["url"]
            # Verify fallback URL was generated using upload_file_id from message_file
            mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))
@@ -0,0 +1,84 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom


def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
    engine = create_engine("sqlite:///:memory:")
    real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)

    user = MagicMock(spec=Account)
    user.id = str(uuid4())
    user.current_tenant_id = str(uuid4())

    repository = SQLAlchemyWorkflowExecutionRepository(
        session_factory=real_session_factory,
        user=user,
        app_id="app-id",
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )

    session_context = MagicMock()
    session_context.__enter__.return_value = session
    session_context.__exit__.return_value = False
    repository._session_factory = MagicMock(return_value=session_context)
    return repository
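
# Editorial sketch (not part of this change): why __enter__ is wired to the mocked
# session above. The repository is presumed to open sessions through the injected
# factory roughly like this, so the mock context manager must yield the test session:
#
#     with self._session_factory() as session:
#         existing = session.get(WorkflowRun, execution.id_)
#         session.merge(model)
#         session.commit()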


def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
    return WorkflowExecution.new(
        id_=execution_id,
        workflow_id="workflow-id",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0.0",
        graph={"nodes": [], "edges": []},
        inputs={"query": "hello"},
        started_at=started_at,
    )


def test_save_uses_execution_started_at_when_record_does_not_exist():
    session = MagicMock()
    session.get.return_value = None
    repository = _build_repository_with_mocked_session(session)

    started_at = datetime(2026, 1, 1, 12, 0, 0)
    execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)

    repository.save(execution)

    saved_model = session.merge.call_args.args[0]
    assert saved_model.created_at == started_at
    session.commit.assert_called_once()


def test_save_preserves_existing_created_at_when_record_already_exists():
    session = MagicMock()
    repository = _build_repository_with_mocked_session(session)

    execution_id = str(uuid4())
    existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
    existing_run = WorkflowRun()
    existing_run.id = execution_id
    existing_run.tenant_id = repository._tenant_id
    existing_run.created_at = existing_created_at
    session.get.return_value = existing_run

    execution = _build_execution(
        execution_id=execution_id,
        started_at=datetime(2026, 1, 1, 12, 30, 0),
    )

    repository.save(execution)

    saved_model = session.merge.call_args.args[0]
    assert saved_model.created_at == existing_created_at
    session.commit.assert_called_once()
@@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch

import pytest

from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
from dify_graph.variables.variables import StringVariable


class StubCoordinator:
@@ -278,3 +280,17 @@ class TestGraphRuntimeState:
        assert restored_execution.started is True

        assert new_stub.state == "configured"

    def test_snapshot_restore_preserves_updated_conversation_variable(self):
        variable_pool = VariablePool(
            conversation_variables=[StringVariable(name="session_name", value="before")],
        )
        variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after")

        state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
        snapshot = state.dumps()
        restored = GraphRuntimeState.from_snapshot(snapshot)

        restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name"))
        assert restored_value is not None
        assert restored_value.value == "after"

@@ -4,15 +4,13 @@ from typing import Any

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.entities import GraphInitParams
from dify_graph.enums import UserFrom
from dify_graph.graph import Graph
from dify_graph.graph.validation import GraphValidationError
from dify_graph.nodes import NodeType
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params


def _build_iteration_graph(node_id: str) -> dict[str, Any]:
@@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]:


def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory:
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
    graph_init_params = build_test_graph_init_params(
        workflow_id="workflow",
        graph_config=graph_config,
        tenant_id="tenant",
        app_id="app",
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    graph_runtime_state = GraphRuntimeState(
Some files were not shown because too many files have changed in this diff.