Merge branch 'main' into feature/smtp-oauth2-support

commit 4a8ac18879
Author: -LAN-
Date: 2025-09-26 17:35:12 +08:00 (committed by GitHub)
Signature: GPG Key ID B5690EEEBB952194 (no known key found for this signature in database)

47 changed files with 922 additions and 826 deletions

View File

@ -15,10 +15,12 @@ jobs:
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
python-version: "3.11"
- run: |
cd api
uv sync --dev
# fmt first to avoid line too long
uv run ruff format ..
# Fix lint errors
uv run ruff check --fix .
# Format code

View File

@ -38,7 +38,7 @@ uv run --directory api basedpyright # Type checking
```bash
cd web
pnpm lint # Run ESLint
pnpm eslint-fix # Fix ESLint issues
pnpm lint:fix # Fix ESLint issues
pnpm test # Run Jest tests
```

View File

@ -421,6 +421,9 @@ SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5
SSRF_POOL_MAX_CONNECTIONS=100
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database
@ -431,6 +434,10 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox
CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_STRING_LENGTH=80000
@ -474,7 +481,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration

View File

@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
default=10.0,
)
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
description="Maximum number of concurrent connections for the code execution HTTP client",
default=100,
)
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
default=20,
)
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
default=5.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="Maximum allowed numeric value in code execution",
default=9223372036854775807,
@ -153,6 +168,11 @@ class CodeExecutionSandboxConfig(BaseSettings):
default=1000,
)
CODE_EXECUTION_SSL_VERIFY: bool = Field(
description="Enable or disable SSL verification for code execution requests",
default=True,
)
class PluginConfig(BaseSettings):
"""
@ -404,6 +424,21 @@ class HttpConfig(BaseSettings):
default=5,
)
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
description="Maximum number of concurrent connections for the SSRF HTTP client",
default=100,
)
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
default=20,
)
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
default=5.0,
)
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
" when the app is behind a single trusted reverse proxy.",
@ -542,11 +577,6 @@ class WorkflowConfig(BaseSettings):
default=5,
)
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
description="Maximum allowed depth for nested parallel executions",
default=3,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,

View File

@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
@ -797,24 +796,6 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows")

View File

@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api
from controllers.console.app.error import (
ConversationCompletedError,
@ -609,18 +608,6 @@ class DefaultRagPipelineBlockConfigApi(Resource):
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
class RagPipelineConfigApi(Resource):
"""Resource for rag pipeline configuration."""
@setup_required
@login_required
@account_initialization_required
def get(self, pipeline_id):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class PublishedAllRagPipelineApi(Resource):
@setup_required
@login_required
@ -985,10 +972,6 @@ api.add_resource(
DraftRagPipelineApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
)
api.add_resource(
RagPipelineConfigApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
)
api.add_resource(
DraftRagPipelineRunApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",

View File

@ -24,20 +24,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
try:
with Session(db.engine) as session:
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID.value
user_model = None
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
if is_anonymous:
user_model = (
session.query(EndUser)
.where(
@ -46,11 +40,21 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
)
.first()
)
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID.value,
is_anonymous=is_anonymous,
session_id=user_id,
)
session.add(user_model)

View File

@ -30,7 +30,6 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from models.model import EndUser
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
@ -311,8 +310,6 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_file(
@ -406,9 +403,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not current_user:
raise ValueError("current_user is required")
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,

View File

@ -551,7 +551,7 @@ class AdvancedChatAppGenerateTaskPipeline:
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from threading import Lock
from typing import Any
from httpx import Timeout, post
import httpx
from pydantic import BaseModel
from yarl import URL
@ -13,9 +13,17 @@ from core.helper.code_executor.javascript.javascript_transformer import NodeJsTe
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
from core.helper.http_client_pooling import get_pooled_http_client
logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))
CODE_EXECUTION_SSL_VERIFY = dify_config.CODE_EXECUTION_SSL_VERIFY
_CODE_EXECUTOR_CLIENT_LIMITS = httpx.Limits(
max_connections=dify_config.CODE_EXECUTION_POOL_MAX_CONNECTIONS,
max_keepalive_connections=dify_config.CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS,
keepalive_expiry=dify_config.CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY,
)
_CODE_EXECUTOR_CLIENT_KEY = "code_executor:http_client"
class CodeExecutionError(Exception):
@ -38,6 +46,13 @@ class CodeLanguage(StrEnum):
JAVASCRIPT = "javascript"
def _build_code_executor_client() -> httpx.Client:
return httpx.Client(
verify=CODE_EXECUTION_SSL_VERIFY,
limits=_CODE_EXECUTOR_CLIENT_LIMITS,
)
class CodeExecutor:
dependencies_cache: dict[str, str] = {}
dependencies_cache_lock = Lock()
@ -76,17 +91,21 @@ class CodeExecutor:
"enable_network": True,
}
timeout = httpx.Timeout(
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
pool=None,
)
client = get_pooled_http_client(_CODE_EXECUTOR_CLIENT_KEY, _build_code_executor_client)
try:
response = post(
response = client.post(
str(url),
json=data,
headers=headers,
timeout=Timeout(
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
pool=None,
),
timeout=timeout,
)
if response.status_code == 503:
raise CodeExecutionError("Code execution service is unavailable")
@ -106,8 +125,8 @@ class CodeExecutor:
try:
response_data = response.json()
except:
raise CodeExecutionError("Failed to parse response")
except Exception as e:
raise CodeExecutionError("Failed to parse response") from e
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")

View File

@ -0,0 +1,59 @@
"""HTTP client pooling utilities."""
from __future__ import annotations
import atexit
import threading
from collections.abc import Callable
import httpx
ClientBuilder = Callable[[], httpx.Client]
class HttpClientPoolFactory:
"""Thread-safe factory that maintains reusable HTTP client instances."""
def __init__(self) -> None:
self._clients: dict[str, httpx.Client] = {}
self._lock = threading.Lock()
def get_or_create(self, key: str, builder: ClientBuilder) -> httpx.Client:
"""Return a pooled client associated with ``key`` creating it on demand."""
client = self._clients.get(key)
if client is not None:
return client
with self._lock:
client = self._clients.get(key)
if client is None:
client = builder()
self._clients[key] = client
return client
def close_all(self) -> None:
"""Close all pooled clients and clear the pool."""
with self._lock:
for client in self._clients.values():
client.close()
self._clients.clear()
_factory = HttpClientPoolFactory()
def get_pooled_http_client(key: str, builder: ClientBuilder) -> httpx.Client:
"""Return a pooled client for the given ``key`` using ``builder`` when missing."""
return _factory.get_or_create(key, builder)
def close_all_pooled_clients() -> None:
"""Close every client created through the pooling factory."""
_factory.close_all()
def _register_shutdown_hook() -> None:
atexit.register(close_all_pooled_clients)
_register_shutdown_hook()
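
For reference, a minimal usage sketch of the new pooling helper. The key name, builder, and URL below are illustrative only; the helper itself and the `httpx.Limits`/`httpx.Timeout` usage mirror the code-executor and SSRF clients elsewhere in this commit.

```python
import httpx

from core.helper.http_client_pooling import get_pooled_http_client


def _build_example_client() -> httpx.Client:
    # Connection limits live on the shared client; they are applied once
    # when the client is first built and cached by the factory.
    return httpx.Client(
        verify=True,
        limits=httpx.Limits(
            max_connections=100,
            max_keepalive_connections=20,
            keepalive_expiry=5.0,
        ),
    )


# Repeated calls with the same key return the same client instance.
client = get_pooled_http_client("example:http_client", _build_example_client)

# Timeouts remain per-request; only the pool configuration is shared.
response = client.get("https://example.com/healthz", timeout=httpx.Timeout(5.0))
```

The factory checks the cache before taking the lock and re-checks it after, so the common path stays lock-free, and `close_all_pooled_clients` is registered with `atexit` so pooled connections are released at interpreter shutdown.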

View File

@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:

View File

@ -8,27 +8,23 @@ import time
import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
logger = logging.getLogger(__name__)
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True
try:
config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(config_value).lower()
if http_request_node_ssl_verify_lower == "true":
http_request_node_ssl_verify = True
elif http_request_node_ssl_verify_lower == "false":
http_request_node_ssl_verify = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
http_request_node_ssl_verify = True
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
_SSRF_CLIENT_LIMITS = httpx.Limits(
max_connections=dify_config.SSRF_POOL_MAX_CONNECTIONS,
max_keepalive_connections=dify_config.SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS,
keepalive_expiry=dify_config.SSRF_POOL_KEEPALIVE_EXPIRY,
)
class MaxRetriesExceededError(ValueError):
"""Raised when the maximum number of retries is exceeded."""
@ -36,6 +32,45 @@ class MaxRetriesExceededError(ValueError):
pass
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTP_URL,
),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
),
}
def _build_ssrf_client(verify: bool) -> httpx.Client:
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
proxy=dify_config.SSRF_PROXY_ALL_URL,
verify=verify,
limits=_SSRF_CLIENT_LIMITS,
)
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
return httpx.Client(
mounts=_create_proxy_mounts(),
verify=verify,
limits=_SSRF_CLIENT_LIMITS,
)
return httpx.Client(verify=verify, limits=_SSRF_CLIENT_LIMITS)
def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
if not isinstance(ssl_verify_enabled, bool):
raise ValueError("SSRF client verify flag must be a boolean")
return get_pooled_http_client(
_SSL_VERIFIED_POOL_KEY if ssl_verify_enabled else _SSL_UNVERIFIED_POOL_KEY,
lambda: _build_ssrf_client(verify=ssl_verify_enabled),
)
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -50,33 +85,22 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
)
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = http_request_node_ssl_verify
ssl_verify = kwargs.pop("ssl_verify")
# prioritize per-call option, which can be switched on and off inside the HTTP node on the web UI
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
retries = 0
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
}
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
else:
with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:
return response
else:
logger.warning(
"Received status code %s for URL %s which is in the force list", response.status_code, url
"Received status code %s for URL %s which is in the force list",
response.status_code,
url,
)
except httpx.RequestError as e:
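
A hedged sketch of how the refactored SSRF helper would now be called (assuming the module lives at `core.helper.ssrf_proxy`; the URLs are placeholders). The per-call `ssl_verify` flag selects between the two pooled clients, and omitting it falls back to `HTTP_REQUEST_NODE_SSL_VERIFY`:

```python
from core.helper import ssrf_proxy

# Uses the pooled "ssrf:verified" client (default SSL verification).
response = ssrf_proxy.make_request("GET", "https://example.com/api/ping")

# Disable certificate verification for this call only; this routes the
# request through the separate "ssrf:unverified" pooled client.
response = ssrf_proxy.make_request(
    "POST",
    "https://self-signed.internal.example.com/hook",
    json={"ok": True},
    ssl_verify=False,
)
```

Because both clients are built once and reused, proxy mounts and connection limits are no longer re-created on every request as they were with the previous per-call `httpx.Client(...)` context managers.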

View File

@ -28,7 +28,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.node_events import AgentLogEvent
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@ -462,19 +461,18 @@ class LLMGenerator:
)
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
if not raw_agent_log:
return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
def dict_of_event(event: AgentLogEvent):
return {
"status": event.status,
"error": event.error,
"data": event.data,
return [
{
"status": event["status"],
"error": event["error"],
"data": event["data"],
}
return [dict_of_event(event) for event in parsed]
for event in raw_agent_log
]
inputs = last_run.load_full_inputs(session, storage)
last_run_dict = {

View File

@ -1,38 +1,28 @@
import json
import logging
from collections.abc import Sequence
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
convert_to_span_id,
convert_to_trace_id,
create_link,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_SYSTEM,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
@ -40,6 +30,15 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.aliyun_trace.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
extract_retrieval_documents,
get_user_id_from_message_data,
get_workflow_node_status,
serialize_json_data,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
@ -52,12 +51,11 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
from models import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@ -68,8 +66,7 @@ class AliyunDataTrace(BaseTraceInstance):
aliyun_config: AliyunConfig,
):
super().__init__(aliyun_config)
base_url = aliyun_config.endpoint.rstrip("/")
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
endpoint = build_endpoint(aliyun_config.endpoint, aliyun_config.license_key)
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
@ -95,423 +92,422 @@ class AliyunDataTrace(BaseTraceInstance):
try:
return self.trace_client.get_project_url()
except Exception as e:
logger.info("Aliyun get run url failed: %s", str(e), exc_info=True)
raise ValueError(f"Aliyun get run url failed: {str(e)}")
logger.info("Aliyun get project url failed: %s", str(e), exc_info=True)
raise ValueError(f"Aliyun get project url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
links = []
if trace_info.trace_id:
links.append(create_link(trace_id_str=trace_info.trace_id))
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info, links)
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(trace_info.workflow_run_id),
workflow_span_id=convert_to_span_id(trace_info.workflow_run_id, "workflow"),
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
self.add_workflow_span(trace_info, trace_metadata)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
node_span = self.build_workflow_node_span(node_execution, trace_info, trace_metadata)
self.trace_client.add_span(node_span)
def message_trace(self, trace_info: MessageTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
message_id = trace_info.message_id
user_id = get_user_id_from_message_data(message_data)
status = create_status_from_error(trace_info.error)
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=user_id,
links=create_links_from_trace_id(trace_info.trace_id),
)
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
links.append(create_link(trace_id_str=trace_info.trace_id))
inputs_json = serialize_json_data(trace_info.inputs)
outputs_str = str(trace_info.outputs)
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_str,
),
status=status,
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(message_span)
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
app_model_config = getattr(message_data, "app_model_config", {})
pre_prompt = getattr(app_model_config, "pre_prompt", "")
inputs_data = getattr(trace_info.message_data, "inputs", {})
inputs_data = getattr(message_data, "inputs", {})
llm_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=message_span_id,
span_id=convert_to_span_id(message_id, "llm"),
name="llm",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=inputs_json,
outputs=outputs_str,
),
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
GEN_AI_PROMPT_TEMPLATE_VARIABLE: serialize_json_data(inputs_data),
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: str(trace_info.outputs),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
GEN_AI_PROMPT: inputs_json,
GEN_AI_COMPLETION: outputs_str,
},
status=status,
links=trace_metadata.links,
)
self.trace_client.add_span(llm_span)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
links.append(create_link(trace_id_str=trace_info.trace_id))
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
documents_data = extract_retrieval_documents(trace_info.documents)
documents_json = serialize_json_data(documents_data)
inputs_str = str(trace_info.inputs)
dataset_retrieval_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name="dataset_retrieval",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.RETRIEVER,
inputs=inputs_str,
outputs=documents_json,
),
RETRIEVAL_QUERY: inputs_str,
RETRIEVAL_DOCUMENT: documents_json,
},
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(dataset_retrieval_span)
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
status = create_status_from_error(trace_info.error)
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
links.append(create_link(trace_id_str=trace_info.trace_id))
tool_config_json = serialize_json_data(trace_info.tool_config)
tool_inputs_json = serialize_json_data(trace_info.tool_inputs)
inputs_json = serialize_json_data(trace_info.inputs)
tool_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name=trace_info.tool_name,
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TOOL,
inputs=inputs_json,
outputs=str(trace_info.tool_outputs),
),
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
TOOL_DESCRIPTION: tool_config_json,
TOOL_PARAMETERS: tool_inputs_json,
},
status=status,
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(tool_span)
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
return workflow_node_executions
return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id)
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
):
try:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)
return node_span
except Exception as e:
logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
span_status: Status = Status(StatusCode.UNSET)
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
span_status = Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
span_status = Status(StatusCode.ERROR, str(node_execution.error))
return span_status
def build_workflow_task_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
inputs_json = serialize_json_data(node_execution.inputs)
outputs_json = serialize_json_data(node_execution.outputs)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TASK,
inputs=inputs_json,
outputs=outputs_json,
),
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_tool_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
inputs_json = serialize_json_data(node_execution.inputs or {})
outputs_json = serialize_json_data(node_execution.outputs)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.TOOL,
inputs=inputs_json,
outputs=outputs_json,
),
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
TOOL_DESCRIPTION: serialize_json_data(tool_des),
TOOL_PARAMETERS: inputs_json,
},
status=self.get_workflow_node_status(node_execution),
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_retrieval_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
input_value = str(node_execution.inputs.get("query", "")) if node_execution.inputs else ""
output_value = serialize_json_data(node_execution.outputs.get("result", [])) if node_execution.outputs else ""
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.RETRIEVER,
inputs=input_value,
outputs=output_value,
),
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=self.get_workflow_node_status(node_execution),
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def build_workflow_llm_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
self, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution, trace_metadata: TraceMetadata
) -> SpanData:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompts_json = serialize_json_data(process_data.get("prompts", []))
text_output = str(outputs.get("text", ""))
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
trace_id=trace_metadata.trace_id,
parent_span_id=trace_metadata.workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=prompts_json,
outputs=text_output,
),
GEN_AI_MODEL_NAME: process_data.get("model_name") or "",
GEN_AI_SYSTEM: process_data.get("model_provider") or "",
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_PROMPT: prompts_json,
GEN_AI_COMPLETION: text_output,
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "",
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=self.get_workflow_node_status(node_execution),
status=get_workflow_node_status(node_execution),
links=trace_metadata.links,
)
def add_workflow_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link]
):
def add_workflow_span(self, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
user_id = trace_info.metadata.get("user_id")
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id: # chatflow
status = create_status_from_error(trace_info.error)
inputs_json = serialize_json_data(trace_info.workflow_run_inputs)
outputs_json = serialize_json_data(trace_info.workflow_run_outputs)
if message_span_id:
message_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query") or "",
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=trace_info.workflow_run_inputs.get("sys.query") or "",
outputs=outputs_json,
),
status=status,
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(message_span)
workflow_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
span_id=trace_metadata.workflow_span_id,
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
attributes=create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.CHAIN,
inputs=inputs_json,
outputs=outputs_json,
),
status=status,
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(workflow_span)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
status = create_status_from_error(trace_info.error)
trace_id = convert_to_trace_id(message_id)
links = []
if trace_info.trace_id:
links.append(create_link(trace_id_str=trace_info.trace_id))
trace_metadata = TraceMetadata(
trace_id=convert_to_trace_id(message_id),
workflow_span_id=0,
session_id=trace_info.metadata.get("conversation_id") or "",
user_id=str(trace_info.metadata.get("user_id") or ""),
links=create_links_from_trace_id(trace_info.trace_id),
)
inputs_json = serialize_json_data(trace_info.inputs)
suggested_question_json = serialize_json_data(trace_info.suggested_question)
suggested_question_span = SpanData(
trace_id=trace_id,
trace_id=trace_metadata.trace_id,
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=convert_to_span_id(message_id, "suggested_question"),
name="suggested_question",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
**create_common_span_attributes(
session_id=trace_metadata.session_id,
user_id=trace_metadata.user_id,
span_kind=GenAISpanKind.LLM,
inputs=inputs_json,
outputs=suggested_question_json,
),
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "",
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "",
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
GEN_AI_PROMPT: inputs_json,
GEN_AI_COMPLETION: suggested_question_json,
},
status=status,
links=links,
links=trace_metadata.links,
)
self.trace_client.add_span(suggested_question_span)
def extract_retrieval_documents(documents: list[Document]):
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@ -7,6 +7,8 @@ import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Final
from urllib.parse import urljoin
import httpx
from opentelemetry import trace as trace_api
@ -20,8 +22,12 @@ from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
DEFAULT_TIMEOUT: Final[int] = 5
DEFAULT_MAX_QUEUE_SIZE: Final[int] = 1000
DEFAULT_SCHEDULE_DELAY_SEC: Final[int] = 5
DEFAULT_MAX_EXPORT_BATCH_SIZE: Final[int] = 50
logger = logging.getLogger(__name__)
@ -31,9 +37,9 @@ class TraceClient:
self,
service_name: str,
endpoint: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
schedule_delay_sec: int = DEFAULT_SCHEDULE_DELAY_SEC,
max_export_batch_size: int = DEFAULT_MAX_EXPORT_BATCH_SIZE,
):
self.endpoint = endpoint
self.resource = Resource(
@ -63,9 +69,9 @@ class TraceClient:
def export(self, spans: Sequence[ReadableSpan]):
self.exporter.export(spans)
def api_check(self):
def api_check(self) -> bool:
try:
response = httpx.head(self.endpoint, timeout=5)
response = httpx.head(self.endpoint, timeout=DEFAULT_TIMEOUT)
if response.status_code == 405:
return True
else:
@ -75,12 +81,13 @@ class TraceClient:
logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self):
def get_project_url(self) -> str:
return "https://arms.console.aliyun.com/#/llm"
def add_span(self, span_data: SpanData):
def add_span(self, span_data: SpanData | None) -> None:
if span_data is None:
return
span: ReadableSpan = self.span_builder.build_span(span_data)
with self.condition:
if len(self.queue) == self.max_queue_size:
@ -92,14 +99,14 @@ class TraceClient:
if len(self.queue) >= self.max_export_batch_size:
self.condition.notify()
def _worker(self):
def _worker(self) -> None:
while not self.done:
with self.condition:
if len(self.queue) < self.max_export_batch_size and not self.done:
self.condition.wait(timeout=self.schedule_delay_sec)
self._export_batch()
def _export_batch(self):
def _export_batch(self) -> None:
spans_to_export: list[ReadableSpan] = []
with self.condition:
while len(spans_to_export) < self.max_export_batch_size and self.queue:
@ -111,7 +118,7 @@ class TraceClient:
except Exception as e:
logger.debug("Error exporting spans: %s", e)
def shutdown(self):
def shutdown(self) -> None:
with self.condition:
self.done = True
self.condition.notify_all()
@ -121,7 +128,7 @@ class TraceClient:
class SpanBuilder:
def __init__(self, resource):
def __init__(self, resource: Resource) -> None:
self.resource = resource
self.instrumentation_scope = InstrumentationScope(
__name__,
@ -167,8 +174,12 @@ class SpanBuilder:
def create_link(trace_id_str: str) -> Link:
placeholder_span_id = 0x0000000000000000
trace_id = int(trace_id_str, 16)
placeholder_span_id = INVALID_SPAN_ID
try:
trace_id = int(trace_id_str, 16)
except ValueError as e:
raise ValueError(f"Invalid trace ID format: {trace_id_str}") from e
span_context = SpanContext(
trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED)
)
@ -184,26 +195,29 @@ def generate_span_id() -> int:
def convert_to_trace_id(uuid_v4: str | None) -> int:
if uuid_v4 is None:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
def convert_string_to_id(string: str | None) -> int:
if not string:
return generate_span_id()
hash_bytes = hashlib.sha256(string.encode("utf-8")).digest()
id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
return id
return int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int:
if uuid_v4 is None:
raise ValueError("UUID cannot be None")
try:
uuid_obj = uuid.UUID(uuid_v4)
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
except ValueError as e:
raise ValueError(f"Invalid UUID input: {uuid_v4}") from e
combined_key = f"{uuid_obj.hex}-{span_type}"
return convert_string_to_id(combined_key)
@ -212,5 +226,11 @@ def convert_datetime_to_nanoseconds(start_time_a: datetime | None) -> int | None
if start_time_a is None:
return None
timestamp_in_seconds = start_time_a.timestamp()
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
return timestamp_in_nanoseconds
return int(timestamp_in_seconds * 1e9)
def build_endpoint(base_url: str, license_key: str) -> str:
if "log.aliyuncs.com" in base_url: # cms2.0 endpoint
return urljoin(base_url, f"adapt_{license_key}/api/v1/traces")
else: # xtrace endpoint
return urljoin(base_url, f"adapt_{license_key}/api/otlp/traces")
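
A quick sketch of what the two branches of `build_endpoint` produce (hosts and the license key are placeholders; the import path matches the one used earlier in this commit):

```python
from core.ops.aliyun_trace.data_exporter.traceclient import build_endpoint

# CMS 2.0 endpoints are detected by the "log.aliyuncs.com" host:
print(build_endpoint("https://my-project.cn-hangzhou.log.aliyuncs.com/", "lk-123"))
# https://my-project.cn-hangzhou.log.aliyuncs.com/adapt_lk-123/api/v1/traces

# Anything else is treated as an xtrace endpoint:
print(build_endpoint("https://tracing-analysis-dc-hz.aliyuncs.com/", "lk-123"))
# https://tracing-analysis-dc-hz.aliyuncs.com/adapt_lk-123/api/otlp/traces
```

Note that `urljoin` resolves the relative path against the endpoint, so a trailing slash on the base URL keeps any existing path segments intact.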

View File

@ -1,18 +1,33 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from pydantic import BaseModel, Field
@dataclass
class TraceMetadata:
"""Metadata for trace operations, containing common attributes for all spans in a trace."""
trace_id: int
workflow_span_id: int
session_id: str
user_id: str
links: list[trace_api.Link]
class SpanData(BaseModel):
"""Data model for span information in Aliyun trace system."""
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: int | None = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
attributes: dict[str, Any] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")

View File

@ -1,56 +1,37 @@
from enum import StrEnum
from typing import Final
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
# Public attributes
GEN_AI_SESSION_ID: Final[str] = "gen_ai.session.id"
GEN_AI_USER_ID: Final[str] = "gen_ai.user.id"
GEN_AI_USER_NAME: Final[str] = "gen_ai.user.name"
GEN_AI_SPAN_KIND: Final[str] = "gen_ai.span.kind"
GEN_AI_FRAMEWORK: Final[str] = "gen_ai.framework"
GEN_AI_USER_ID = "gen_ai.user.id"
# Chain attributes
INPUT_VALUE: Final[str] = "input.value"
OUTPUT_VALUE: Final[str] = "output.value"
GEN_AI_USER_NAME = "gen_ai.user.name"
# Retriever attributes
RETRIEVAL_QUERY: Final[str] = "retrieval.query"
RETRIEVAL_DOCUMENT: Final[str] = "retrieval.document"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
# LLM attributes
GEN_AI_MODEL_NAME: Final[str] = "gen_ai.model_name"
GEN_AI_SYSTEM: Final[str] = "gen_ai.system"
GEN_AI_USAGE_INPUT_TOKENS: Final[str] = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS: Final[str] = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS: Final[str] = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: Final[str] = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE: Final[str] = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT: Final[str] = "gen_ai.prompt"
GEN_AI_COMPLETION: Final[str] = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON: Final[str] = "gen_ai.response.finish_reason"
GEN_AI_FRAMEWORK = "gen_ai.framework"
# Chain
INPUT_VALUE = "input.value"
OUTPUT_VALUE = "output.value"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# LLM
GEN_AI_MODEL_NAME = "gen_ai.model_name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
# Tool attributes
TOOL_NAME: Final[str] = "tool.name"
TOOL_DESCRIPTION: Final[str] = "tool.description"
TOOL_PARAMETERS: Final[str] = "tool.parameters"
class GenAISpanKind(StrEnum):

View File

@ -0,0 +1,95 @@
import json
from typing import Any
from opentelemetry.trace import Link, Status, StatusCode
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
GenAISpanKind,
)
from core.rag.models.document import Document
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.ext_database import db
from models import EndUser
# Constants
DEFAULT_JSON_ENSURE_ASCII = False
DEFAULT_FRAMEWORK_NAME = "dify"
def get_user_id_from_message_data(message_data) -> str:
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser | None = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
return user_id
def create_status_from_error(error: str | None) -> Status:
if error:
return Status(StatusCode.ERROR, error)
return Status(StatusCode.OK)
def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
return Status(StatusCode.OK)
if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
return Status(StatusCode.ERROR, str(node_execution.error))
return Status(StatusCode.UNSET)
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
from core.ops.aliyun_trace.data_exporter.traceclient import create_link
links = []
if trace_id:
links.append(create_link(trace_id_str=trace_id))
return links
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data
def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
return json.dumps(data, ensure_ascii=ensure_ascii)
def create_common_span_attributes(
session_id: str = "",
user_id: str = "",
span_kind: str = GenAISpanKind.CHAIN,
framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "",
outputs: str = "",
) -> dict[str, Any]:
return {
GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id,
GEN_AI_SPAN_KIND: span_kind,
GEN_AI_FRAMEWORK: framework,
INPUT_VALUE: inputs,
OUTPUT_VALUE: outputs,
}
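A minimal usage sketch combining the helpers above into attributes for a retrieval step; the session id, query, and the `documents`/`message_data` inputs are assumptions for illustration, and the real span construction in the Aliyun trace exporter may differ:

# Hypothetical inputs; only the helper calls mirror the definitions above.
docs = extract_retrieval_documents(documents)             # documents: list[Document] from a retrieval step
attributes = create_common_span_attributes(
    session_id="session-123",                             # assumed session id
    user_id=get_user_id_from_message_data(message_data),  # message_data: the traced message record
    inputs=serialize_json_data({"query": "What is Dify?"}),
    outputs=serialize_json_data(docs),
)
# `attributes` (with the default CHAIN span kind and "dify" framework) can then
# be attached to the exported span for this step.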

View File

@ -191,7 +191,8 @@ class AliyunConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
# Aliyun uses two endpoint URL formats, which may include a URL path
return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
OPS_FILE_PATH = "ops_trace/"

View File

@ -1,9 +1,11 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.variables import IntegerVariable, NoneSegment
@ -35,6 +37,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@ -239,6 +242,8 @@ class IterationNode(Node):
self._execute_single_iteration_parallel,
index=index,
item=item,
flask_app=current_app._get_current_object(), # type: ignore
context_vars=contextvars.copy_context(),
)
future_to_index[future] = index
@ -281,26 +286,29 @@ class IterationNode(Node):
self,
index: int,
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
"""Execute a single iteration in parallel mode and return results."""
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
graph_engine = self._create_graph_engine(index, item)
graph_engine = self._create_graph_engine(index, item)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Collect events instead of yielding them directly
for event in self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs_temp,
graph_engine=graph_engine,
):
events.append(event)
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
# Get the output value from the temporary outputs list
output_value = outputs_temp[0] if outputs_temp else None
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
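Passing `flask_app` and `context_vars` into the worker matters because `ThreadPoolExecutor` threads start without a Flask application context or the parent request's context variables. A rough sketch of the pattern such a helper implements (the actual `preserve_flask_contexts` in `libs.flask_utils` may differ in detail):

# Illustrative sketch only, not the real libs.flask_utils implementation.
import contextvars
from contextlib import contextmanager
from flask import Flask

@contextmanager
def _preserve_flask_contexts_sketch(*, flask_app: Flask, context_vars: contextvars.Context):
    # Replay the parent thread's context variables in this worker thread.
    for var, value in context_vars.items():
        var.set(value)
    # Push an app context so current_app, extensions, and db sessions resolve here.
    with flask_app.app_context():
        yield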
def _handle_iteration_success(
self,

View File

@ -8,6 +8,7 @@ from typing import Any
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.http import parse_options_header
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
@ -247,6 +248,25 @@ def _build_from_remote_url(
)
def _extract_filename(url_path: str, content_disposition: str | None) -> str | None:
filename = None
# Try to extract from Content-Disposition header first
if content_disposition:
_, params = parse_options_header(content_disposition)
# RFC 5987 https://datatracker.ietf.org/doc/html/rfc5987: filename* takes precedence over filename
filename = params.get("filename*") or params.get("filename")
# Fallback to URL path if no filename from header
if not filename:
filename = os.path.basename(url_path)
return filename or None
def _guess_mime_type(filename: str) -> str:
"""Guess MIME type from filename, returning empty string if None."""
guessed_mime, _ = mimetypes.guess_type(filename)
return guessed_mime or ""
def _get_remote_file_info(url: str):
file_size = -1
parsed_url = urllib.parse.urlparse(url)
@ -254,23 +274,26 @@ def _get_remote_file_info(url: str):
filename = os.path.basename(url_path)
# Initialize mime_type from filename as fallback
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
mime_type = _guess_mime_type(filename)
resp = ssrf_proxy.head(url, follow_redirects=True)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))
# Re-guess mime_type from updated filename
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = ""
content_disposition = resp.headers.get("Content-Disposition")
extracted_filename = _extract_filename(url_path, content_disposition)
if extracted_filename:
filename = extracted_filename
mime_type = _guess_mime_type(filename)
file_size = int(resp.headers.get("Content-Length", file_size))
# Fallback to Content-Type header if mime_type is still empty
if not mime_type:
mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip()
if not filename:
extension = mimetypes.guess_extension(mime_type) or ".bin"
filename = f"{uuid.uuid4().hex}{extension}"
if not mime_type:
mime_type = _guess_mime_type(filename)
return mime_type, filename, file_size
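A couple of hypothetical calls illustrating the precedence the helpers above implement (Content-Disposition wins over the URL path, with `filename*` preferred per RFC 5987):

# Hypothetical inputs; the results follow from _extract_filename above.
_extract_filename("/downloads/report.pdf", 'attachment; filename="override.pdf"')
# -> "override.pdf"  (header filename wins over the URL path)
_extract_filename("/downloads/report.pdf", None)
# -> "report.pdf"    (falls back to the basename of the URL path)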

View File

@ -167,7 +167,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# App configuration

View File

@ -40,8 +40,6 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
# annotated field with configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# values from pyproject.toml
assert Version(config.project.version) >= Version("1.0.0")

View File

@ -329,20 +329,20 @@ class TestAliyunConfig:
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
def test_endpoint_validation_with_path(self):
"""Test endpoint validation normalizes URL by removing path"""
"""Test endpoint validation preserves path for Aliyun endpoints"""
config = AliyunConfig(
license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
)
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
def test_endpoint_validation_invalid_scheme(self):
"""Test endpoint validation rejects invalid schemes"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_endpoint_validation_no_scheme(self):
"""Test endpoint validation rejects URLs without scheme"""
with pytest.raises(ValidationError, match="URL scheme must be one of"):
with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
def test_license_key_required(self):
@ -350,6 +350,23 @@ class TestAliyunConfig:
with pytest.raises(ValidationError):
AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
def test_valid_endpoint_format_examples(self):
"""Test valid endpoint format examples from comments"""
valid_endpoints = [
# cms2.0 public endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
# cms2.0 intranet endpoint
"https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
# xtrace public endpoint
"http://tracing-cn-heyuan.arms.aliyuncs.com",
# xtrace intranet endpoint
"http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
]
for endpoint in valid_endpoints:
config = AliyunConfig(license_key="test_license", endpoint=endpoint)
assert config.endpoint == endpoint
class TestConfigIntegration:
"""Integration tests for configuration classes"""
@ -382,7 +399,7 @@ class TestConfigIntegration:
assert arize_config.endpoint == "https://arize.com"
assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com"
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
def test_project_default_values(self):
"""Test that project default values are set correctly"""

View File

@ -0,0 +1,115 @@
import re
import pytest
from factories.file_factory import _get_remote_file_info
class _FakeResponse:
def __init__(self, status_code: int, headers: dict[str, str]):
self.status_code = status_code
self.headers = headers
def _mock_head(monkeypatch: pytest.MonkeyPatch, headers: dict[str, str], status_code: int = 200):
def _fake_head(url: str, follow_redirects: bool = True):
return _FakeResponse(status_code=status_code, headers=headers)
monkeypatch.setattr("factories.file_factory.ssrf_proxy.head", _fake_head)
class TestGetRemoteFileInfo:
"""Tests for _get_remote_file_info focusing on filename extraction rules."""
def test_inline_no_filename(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": "inline",
"Content-Type": "application/pdf",
"Content-Length": "123",
},
)
mime_type, filename, size = _get_remote_file_info("http://example.com/some/path/file.pdf")
assert filename == "file.pdf"
assert mime_type == "application/pdf"
assert size == 123
def test_attachment_no_filename(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": "attachment",
"Content-Type": "application/octet-stream",
"Content-Length": "456",
},
)
mime_type, filename, size = _get_remote_file_info("http://example.com/downloads/data.bin")
assert filename == "data.bin"
assert mime_type == "application/octet-stream"
assert size == 456
def test_attachment_quoted_space_filename(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": 'attachment; filename="file name.jpg"',
"Content-Type": "image/jpeg",
"Content-Length": "789",
},
)
mime_type, filename, size = _get_remote_file_info("http://example.com/ignored")
assert filename == "file name.jpg"
assert mime_type == "image/jpeg"
assert size == 789
def test_attachment_filename_star_percent20(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": "attachment; filename*=UTF-8''file%20name.jpg",
"Content-Type": "image/jpeg",
},
)
mime_type, filename, _ = _get_remote_file_info("http://example.com/ignored")
assert filename == "file name.jpg"
assert mime_type == "image/jpeg"
def test_attachment_filename_star_chinese(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": "attachment; filename*=UTF-8''%E6%B5%8B%E8%AF%95%E6%96%87%E4%BB%B6.jpg",
"Content-Type": "image/jpeg",
},
)
mime_type, filename, _ = _get_remote_file_info("http://example.com/ignored")
assert filename == "测试文件.jpg"
assert mime_type == "image/jpeg"
def test_filename_from_url_when_no_header(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
# No Content-Disposition
"Content-Type": "text/plain",
"Content-Length": "12",
},
)
mime_type, filename, size = _get_remote_file_info("http://example.com/static/file.txt")
assert filename == "file.txt"
assert mime_type == "text/plain"
assert size == 12
def test_no_filename_in_url_or_header_generates_uuid_bin(self, monkeypatch: pytest.MonkeyPatch):
_mock_head(
monkeypatch,
{
"Content-Disposition": "inline",
"Content-Type": "application/octet-stream",
},
)
mime_type, filename, _ = _get_remote_file_info("http://example.com/test/")
# Should generate a random hex filename with .bin extension
assert re.match(r"^[0-9a-f]{32}\.bin$", filename) is not None
assert mime_type == "application/octet-stream"

View File

@ -886,6 +886,10 @@ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
# The sandbox service endpoint.
CODE_EXECUTION_ENDPOINT=http://sandbox:8194
CODE_EXECUTION_API_KEY=dify-sandbox
CODE_EXECUTION_SSL_VERIFY=True
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
CODE_MAX_NUMBER=9223372036854775807
CODE_MIN_NUMBER=-9223372036854775808
CODE_MAX_DEPTH=5
@ -904,7 +908,6 @@ WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
WORKFLOW_FILE_UPLOAD_LIMIT=10
# GraphEngine Worker Pool Configuration
@ -1161,6 +1164,9 @@ SSRF_DEFAULT_TIME_OUT=5
SSRF_DEFAULT_CONNECT_TIME_OUT=5
SSRF_DEFAULT_READ_TIME_OUT=5
SSRF_DEFAULT_WRITE_TIME_OUT=5
SSRF_POOL_MAX_CONNECTIONS=100
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
# ------------------------------
# docker env var for specifying vector db type at startup

View File

@ -387,6 +387,10 @@ x-shared-env: &shared-api-worker-env
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5}
CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox}
CODE_EXECUTION_SSL_VERIFY: ${CODE_EXECUTION_SSL_VERIFY:-True}
CODE_EXECUTION_POOL_MAX_CONNECTIONS: ${CODE_EXECUTION_POOL_MAX_CONNECTIONS:-100}
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: ${CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS:-20}
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: ${CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY:-5.0}
CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808}
CODE_MAX_DEPTH: ${CODE_MAX_DEPTH:-5}
@ -403,7 +407,6 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200}
WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5}
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
GRAPH_ENGINE_MIN_WORKERS: ${GRAPH_ENGINE_MIN_WORKERS:-1}
GRAPH_ENGINE_MAX_WORKERS: ${GRAPH_ENGINE_MAX_WORKERS:-10}
@ -502,6 +505,9 @@ x-shared-env: &shared-api-worker-env
SSRF_DEFAULT_CONNECT_TIME_OUT: ${SSRF_DEFAULT_CONNECT_TIME_OUT:-5}
SSRF_DEFAULT_READ_TIME_OUT: ${SSRF_DEFAULT_READ_TIME_OUT:-5}
SSRF_DEFAULT_WRITE_TIME_OUT: ${SSRF_DEFAULT_WRITE_TIME_OUT:-5}
SSRF_POOL_MAX_CONNECTIONS: ${SSRF_POOL_MAX_CONNECTIONS:-100}
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: ${SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS:-20}
SSRF_POOL_KEEPALIVE_EXPIRY: ${SSRF_POOL_KEEPALIVE_EXPIRY:-5.0}
EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80}
EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443}
POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-}
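The `SSRF_POOL_*` and `CODE_EXECUTION_POOL_*` settings added above describe HTTP connection pools. A minimal sketch of how such settings typically map onto an httpx client, assuming they are read into `dify_config`; the actual client construction in the SSRF proxy and code-execution modules may differ:

# Sketch only: applying the pool settings to an httpx client.
import httpx

limits = httpx.Limits(
    max_connections=100,           # SSRF_POOL_MAX_CONNECTIONS
    max_keepalive_connections=20,  # SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS
    keepalive_expiry=5.0,          # SSRF_POOL_KEEPALIVE_EXPIRY (None disables expiry)
)
client = httpx.Client(limits=limits, timeout=httpx.Timeout(5.0))  # SSRF_DEFAULT_TIME_OUT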

View File

@ -1,6 +1,7 @@
import {
createContext,
useContext,
useEffect,
useRef,
} from 'react'
import {
@ -18,13 +19,11 @@ type Shape = {
export const createFileStore = (
value: FileEntity[] = [],
onChange?: (files: FileEntity[]) => void,
) => {
return create<Shape>(set => ({
files: value ? [...value] : [],
setFiles: (files) => {
set({ files })
onChange?.(files)
},
}))
}
@ -55,9 +54,35 @@ export const FileContextProvider = ({
onChange,
}: FileProviderProps) => {
const storeRef = useRef<FileStore | undefined>(undefined)
const onChangeRef = useRef<FileProviderProps['onChange']>(onChange)
const isSyncingRef = useRef(false)
if (!storeRef.current)
storeRef.current = createFileStore(value, onChange)
storeRef.current = createFileStore(value)
// keep latest onChange
useEffect(() => {
onChangeRef.current = onChange
}, [onChange])
// subscribe to store changes and call latest onChange
useEffect(() => {
const store = storeRef.current!
const unsubscribe = store.subscribe((state: Shape) => {
if (isSyncingRef.current) return
onChangeRef.current?.(state.files)
})
return unsubscribe
}, [])
// sync external value into internal store when value changes
useEffect(() => {
const store = storeRef.current!
const nextFiles = value ? [...value] : []
isSyncingRef.current = true
store.setState({ files: nextFiles })
isSyncingRef.current = false
}, [value])
return (
<FileContext.Provider value={storeRef.current}>

View File

@ -53,7 +53,7 @@ const SearchInput: FC<SearchInputProps> = ({
}}
onCompositionEnd={(e) => {
isComposing.current = false
onChange(e.data)
onChange(e.currentTarget.value)
}}
onFocus={() => setFocus(true)}
onBlur={() => setFocus(false)}

View File

@ -14,16 +14,6 @@ export const usePipelineConfig = () => {
const pipelineId = useStore(s => s.pipelineId)
const workflowStore = useWorkflowStore()
const handleUpdateWorkflowConfig = useCallback((config: Record<string, any>) => {
const { setWorkflowConfig } = workflowStore.getState()
setWorkflowConfig(config)
}, [workflowStore])
useWorkflowConfig(
pipelineId ? `/rag/pipelines/${pipelineId}/workflows/draft/config` : '',
handleUpdateWorkflowConfig,
)
const handleUpdateNodesDefaultConfigs = useCallback((nodesDefaultConfigs: Record<string, any> | Record<string, any>[]) => {
const { setNodesDefaultConfigs } = workflowStore.getState()
let res: Record<string, any> = {}

View File

@ -33,13 +33,6 @@ export const useWorkflowInit = () => {
workflowStore.setState({ appId: appDetail.id, appName: appDetail.name })
}, [appDetail.id, workflowStore])
const handleUpdateWorkflowConfig = useCallback((config: Record<string, any>) => {
const { setWorkflowConfig } = workflowStore.getState()
setWorkflowConfig(config)
}, [workflowStore])
useWorkflowConfig(`/apps/${appDetail.id}/workflows/draft/config`, handleUpdateWorkflowConfig)
const handleUpdateWorkflowFileUploadConfig = useCallback((config: FileUploadConfigResponse) => {
const { setFileUploadConfig } = workflowStore.getState()
setFileUploadConfig(config)

View File

@ -35,8 +35,6 @@ export const NODE_LAYOUT_HORIZONTAL_PADDING = 60
export const NODE_LAYOUT_VERTICAL_PADDING = 60
export const NODE_LAYOUT_MIN_DISTANCE = 100
export const PARALLEL_DEPTH_LIMIT = 3
export const RETRIEVAL_OUTPUT_STRUCT = `{
"content": "",
"title": "",

View File

@ -70,7 +70,7 @@ export const useNodesInteractions = () => {
const reactflow = useReactFlow()
const { store: workflowHistoryStore } = useWorkflowHistoryStore()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const { checkNestedParallelLimit, getAfterNodesInSameBranch } = useWorkflow()
const { getAfterNodesInSameBranch } = useWorkflow()
const { getNodesReadOnly } = useNodesReadOnly()
const { getWorkflowReadOnly } = useWorkflowReadOnly()
const { handleSetHelpline } = useHelpline()
@ -436,21 +436,13 @@ export const useNodesInteractions = () => {
draft.push(newEdge)
})
if (checkNestedParallelLimit(newNodes, newEdges, targetNode)) {
setNodes(newNodes)
setEdges(newEdges)
setNodes(newNodes)
setEdges(newEdges)
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
nodeId: targetNode?.id,
})
}
else {
const { setConnectingNodePayload, setEnteringNodePayload }
= workflowStore.getState()
setConnectingNodePayload(undefined)
setEnteringNodePayload(undefined)
}
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
nodeId: targetNode?.id,
})
},
[
getNodesReadOnly,
@ -458,7 +450,6 @@ export const useNodesInteractions = () => {
workflowStore,
handleSyncWorkflowDraft,
saveStateToHistory,
checkNestedParallelLimit,
],
)
@ -934,13 +925,8 @@ export const useNodesInteractions = () => {
if (newEdge) draft.push(newEdge)
})
if (checkNestedParallelLimit(newNodes, newEdges, prevNode)) {
setNodes(newNodes)
setEdges(newEdges)
}
else {
return false
}
setNodes(newNodes)
setEdges(newEdges)
}
if (!prevNodeId && nextNodeId) {
const nextNodeIndex = nodes.findIndex(node => node.id === nextNodeId)
@ -1087,17 +1073,11 @@ export const useNodesInteractions = () => {
draft.push(newEdge)
})
if (checkNestedParallelLimit(newNodes, newEdges, nextNode)) {
setNodes(newNodes)
setEdges(newEdges)
}
else {
return false
}
setNodes(newNodes)
setEdges(newEdges)
}
else {
if (checkNestedParallelLimit(newNodes, edges)) setNodes(newNodes)
else return false
setNodes(newNodes)
}
}
if (prevNodeId && nextNodeId) {
@ -1297,7 +1277,6 @@ export const useNodesInteractions = () => {
saveStateToHistory,
workflowStore,
getAfterNodesInSameBranch,
checkNestedParallelLimit,
nodesMetaDataMap,
],
)

View File

@ -2,7 +2,6 @@ import {
useCallback,
} from 'react'
import { uniqBy } from 'lodash-es'
import { useTranslation } from 'react-i18next'
import {
getIncomers,
getOutgoers,
@ -24,9 +23,7 @@ import {
useStore,
useWorkflowStore,
} from '../store'
import { getParallelInfo } from '../utils'
import {
PARALLEL_DEPTH_LIMIT,
SUPPORT_OUTPUT_VARS_NODE,
} from '../constants'
import type { IterationNodeType } from '../nodes/iteration/types'
@ -44,7 +41,6 @@ import {
import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants'
import { CUSTOM_LOOP_START_NODE } from '@/app/components/workflow/nodes/loop-start/constants'
import { basePath } from '@/utils/var'
import { MAX_PARALLEL_LIMIT } from '@/config'
import { useNodesMetaData } from '.'
export const useIsChatMode = () => {
@ -54,9 +50,7 @@ export const useIsChatMode = () => {
}
export const useWorkflow = () => {
const { t } = useTranslation()
const store = useStoreApi()
const workflowStore = useWorkflowStore()
const { getAvailableBlocks } = useAvailableBlocks()
const { nodesMap } = useNodesMetaData()
@ -290,20 +284,6 @@ export const useWorkflow = () => {
return isUsed
}, [isVarUsedInNodes])
const checkParallelLimit = useCallback((nodeId: string, nodeHandle = 'source') => {
const {
edges,
} = store.getState()
const connectedEdges = edges.filter(edge => edge.source === nodeId && edge.sourceHandle === nodeHandle)
if (connectedEdges.length > MAX_PARALLEL_LIMIT - 1) {
const { setShowTips } = workflowStore.getState()
setShowTips(t('workflow.common.parallelTip.limit', { num: MAX_PARALLEL_LIMIT }))
return false
}
return true
}, [store, workflowStore, t])
const getRootNodesById = useCallback((nodeId: string) => {
const {
getNodes,
@ -374,33 +354,6 @@ export const useWorkflow = () => {
return startNodes
}, [nodesMap, getRootNodesById])
const checkNestedParallelLimit = useCallback((nodes: Node[], edges: Edge[], targetNode?: Node) => {
const startNodes = getStartNodes(nodes, targetNode)
for (let i = 0; i < startNodes.length; i++) {
const {
parallelList,
hasAbnormalEdges,
} = getParallelInfo(startNodes[i], nodes, edges)
const { workflowConfig } = workflowStore.getState()
if (hasAbnormalEdges)
return false
for (let i = 0; i < parallelList.length; i++) {
const parallel = parallelList[i]
if (parallel.depth > (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT)) {
const { setShowTips } = workflowStore.getState()
setShowTips(t('workflow.common.parallelTip.depthLimit', { num: (workflowConfig?.parallel_depth_limit || PARALLEL_DEPTH_LIMIT) }))
return false
}
}
}
return true
}, [t, workflowStore, getStartNodes])
const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => {
const {
edges,
@ -410,9 +363,6 @@ export const useWorkflow = () => {
const sourceNode: Node = nodes.find(node => node.id === source)!
const targetNode: Node = nodes.find(node => node.id === target)!
if (!checkParallelLimit(source!, sourceHandle || 'source'))
return false
if (sourceNode.type === CUSTOM_NOTE_NODE || targetNode.type === CUSTOM_NOTE_NODE)
return false
@ -445,7 +395,7 @@ export const useWorkflow = () => {
}
return !hasCycle(targetNode)
}, [store, checkParallelLimit, getAvailableBlocks])
}, [store, getAvailableBlocks])
return {
getNodeById,
@ -457,8 +407,6 @@ export const useWorkflow = () => {
isVarUsedInNodes,
removeUsedVarInNodes,
isNodeVarsUsedInNodes,
checkParallelLimit,
checkNestedParallelLimit,
isValidConnection,
getBeforeNodeById,
getIterationNodeChildren,

View File

@ -71,7 +71,6 @@ import PanelContextmenu from './panel-contextmenu'
import NodeContextmenu from './node-contextmenu'
import SelectionContextmenu from './selection-contextmenu'
import SyncingDataModal from './syncing-data-modal'
import LimitTips from './limit-tips'
import { setupScrollToNodeListener } from './utils/node-navigation'
import {
useStore,
@ -378,7 +377,6 @@ export const Workflow: FC<WorkflowProps> = memo(({
/>
)
}
<LimitTips />
{children}
<ReactFlow
nodeTypes={nodeTypes}

View File

@ -1,39 +0,0 @@
import {
RiAlertFill,
RiCloseLine,
} from '@remixicon/react'
import { useStore } from './store'
import ActionButton from '@/app/components/base/action-button'
const LimitTips = () => {
const showTips = useStore(s => s.showTips)
const setShowTips = useStore(s => s.setShowTips)
if (!showTips)
return null
return (
<div className='absolute bottom-16 left-1/2 z-[9] flex h-10 -translate-x-1/2 items-center rounded-xl border border-components-panel-border bg-components-panel-bg-blur p-2 shadow-md'>
<div
className='absolute inset-0 rounded-xl opacity-[0.4]'
style={{
background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)',
}}
></div>
<div className='flex h-5 w-5 items-center justify-center'>
<RiAlertFill className='h-4 w-4 text-text-warning-secondary' />
</div>
<div className='system-xs-medium mx-1 px-1 text-text-primary'>
{showTips}
</div>
<ActionButton
className='z-[1]'
onClick={() => setShowTips('')}
>
<RiCloseLine className='h-4 w-4' />
</ActionButton>
</div>
)
}
export default LimitTips

View File

@ -12,7 +12,6 @@ import {
useAvailableBlocks,
useNodesInteractions,
useNodesReadOnly,
useWorkflow,
} from '@/app/components/workflow/hooks'
import BlockSelector from '@/app/components/workflow/block-selector'
import type {
@ -39,7 +38,6 @@ const Add = ({
const { handleNodeAdd } = useNodesInteractions()
const { nodesReadOnly } = useNodesReadOnly()
const { availableNextBlocks } = useAvailableBlocks(nodeData.type, nodeData.isInIteration || nodeData.isInLoop)
const { checkParallelLimit } = useWorkflow()
const handleSelect = useCallback<OnSelectBlock>((type, toolDefaultValue) => {
handleNodeAdd(
@ -52,14 +50,11 @@ const Add = ({
prevNodeSourceHandle: sourceHandle,
},
)
}, [nodeId, sourceHandle, handleNodeAdd])
}, [handleNodeAdd])
const handleOpenChange = useCallback((newOpen: boolean) => {
if (newOpen && !checkParallelLimit(nodeId, sourceHandle))
return
setOpen(newOpen)
}, [checkParallelLimit, nodeId, sourceHandle])
}, [])
const tip = useMemo(() => {
if (isFailBranch)

View File

@ -22,7 +22,6 @@ import {
useIsChatMode,
useNodesInteractions,
useNodesReadOnly,
useWorkflow,
} from '../../../hooks'
import {
useStore,
@ -132,7 +131,6 @@ export const NodeSourceHandle = memo(({
const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration || data.isInLoop)
const isConnectable = !!availableNextBlocks.length
const isChatMode = useIsChatMode()
const { checkParallelLimit } = useWorkflow()
const connected = data._connectedSourceHandleIds?.includes(handleId)
const handleOpenChange = useCallback((v: boolean) => {
@ -140,9 +138,8 @@ export const NodeSourceHandle = memo(({
}, [])
const handleHandleClick = useCallback((e: MouseEvent) => {
e.stopPropagation()
if (checkParallelLimit(id, handleId))
setOpen(v => !v)
}, [checkParallelLimit, id, handleId])
setOpen(v => !v)
}, [])
const handleSelect = useCallback((type: BlockEnum, toolDefaultValue?: ToolDefaultValue) => {
handleNodeAdd(
{

View File

@ -42,6 +42,7 @@ import type { RAGPipelineVariable } from '@/models/pipeline'
import {
AGENT_OUTPUT_STRUCT,
FILE_STRUCT,
HTTP_REQUEST_OUTPUT_STRUCT,
KNOWLEDGE_RETRIEVAL_OUTPUT_STRUCT,
LLM_OUTPUT_STRUCT,
@ -138,6 +139,10 @@ export const varTypeToStructType = (type: VarType): Type => {
[VarType.boolean]: Type.boolean,
[VarType.object]: Type.object,
[VarType.array]: Type.array,
[VarType.arrayString]: Type.array,
[VarType.arrayNumber]: Type.array,
[VarType.arrayObject]: Type.array,
[VarType.arrayFile]: Type.array,
} as any
)[type] || Type.string
)
@ -282,15 +287,6 @@ const findExceptVarInObject = (
children: filteredObj.children,
}
})
if (isFile && Array.isArray(childrenResult)) {
if (childrenResult.length === 0) {
childrenResult = OUTPUT_FILE_SUB_VARIABLES.map(key => ({
variable: key,
type: key === 'size' ? VarType.number : VarType.string,
}))
}
}
}
else {
childrenResult = []
@ -586,17 +582,15 @@ const formatItem = (
variable: outputKey,
type:
output.type === 'array'
? (`Array[${
output.items?.type
? output.items.type.slice(0, 1).toLocaleUpperCase()
+ output.items.type.slice(1)
: 'Unknown'
? (`Array[${output.items?.type
? output.items.type.slice(0, 1).toLocaleUpperCase()
+ output.items.type.slice(1)
: 'Unknown'
}]` as VarType)
: (`${
output.type
? output.type.slice(0, 1).toLocaleUpperCase()
+ output.type.slice(1)
: 'Unknown'
: (`${output.type
? output.type.slice(0, 1).toLocaleUpperCase()
+ output.type.slice(1)
: 'Unknown'
}` as VarType),
})
},
@ -690,9 +684,10 @@ const formatItem = (
const children = (() => {
if (isFile) {
return OUTPUT_FILE_SUB_VARIABLES.map((key) => {
const def = FILE_STRUCT.find(c => c.variable === key)
return {
variable: key,
type: key === 'size' ? VarType.number : VarType.string,
type: def?.type || VarType.string,
}
})
}
@ -714,9 +709,10 @@ const formatItem = (
if (isFile) {
return {
children: OUTPUT_FILE_SUB_VARIABLES.map((key) => {
const def = FILE_STRUCT.find(c => c.variable === key)
return {
variable: key,
type: key === 'size' ? VarType.number : VarType.string,
type: def?.type || VarType.string,
}
}),
}

View File

@ -18,7 +18,6 @@ import { Type } from '../../../llm/types'
import PickerStructurePanel from '@/app/components/workflow/nodes/_base/components/variable/object-child-tree-panel/picker'
import { isSpecialVar, varTypeToStructType } from './utils'
import type { Field } from '@/app/components/workflow/nodes/llm/types'
import { FILE_STRUCT } from '@/app/components/workflow/constants'
import { noop } from 'lodash-es'
import { CodeAssistant, MagicEdit } from '@/app/components/base/icons/src/vender/line/general'
import ManageInputField from './manage-input-field'
@ -106,8 +105,9 @@ const Item: FC<ItemProps> = ({
const objStructuredOutput: StructuredOutput | null = useMemo(() => {
if (!isObj) return null
const properties: Record<string, Field> = {};
(isFile ? FILE_STRUCT : (itemData.children as Var[])).forEach((c) => {
const properties: Record<string, Field> = {}
const childrenVars = (itemData.children as Var[]) || []
childrenVars.forEach((c) => {
properties[c.variable] = {
type: varTypeToStructType(c.type),
}
@ -120,7 +120,7 @@ const Item: FC<ItemProps> = ({
additionalProperties: false,
},
}
}, [isFile, isObj, itemData.children])
}, [isObj, itemData.children])
const structuredOutput = (() => {
if (isStructureOutput)
@ -448,4 +448,5 @@ const VarReferenceVars: FC<Props> = ({
</>
)
}
export default React.memo(VarReferenceVars)

View File

@ -55,6 +55,7 @@ const Panel: FC<NodePanelProps<ListFilterNodeType>> = ({
value={inputs.variable || []}
onChange={handleVarChanges}
filterVar={filterVar}
isSupportFileVar={false}
typePlaceHolder='Array'
/>
</Field>

View File

@ -29,10 +29,6 @@ export type WorkflowSliceShape = {
setControlPromptEditorRerenderKey: (controlPromptEditorRerenderKey: number) => void
showImportDSLModal: boolean
setShowImportDSLModal: (showImportDSLModal: boolean) => void
showTips: string
setShowTips: (showTips: string) => void
workflowConfig?: Record<string, any>
setWorkflowConfig: (workflowConfig: Record<string, any>) => void
fileUploadConfig?: FileUploadConfigResponse
setFileUploadConfig: (fileUploadConfig: FileUploadConfigResponse) => void
}
@ -59,10 +55,6 @@ export const createWorkflowSlice: StateCreator<WorkflowSliceShape> = set => ({
setControlPromptEditorRerenderKey: controlPromptEditorRerenderKey => set(() => ({ controlPromptEditorRerenderKey })),
showImportDSLModal: false,
setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })),
showTips: '',
setShowTips: showTips => set(() => ({ showTips })),
workflowConfig: undefined,
setWorkflowConfig: workflowConfig => set(() => ({ workflowConfig })),
fileUploadConfig: undefined,
setFileUploadConfig: fileUploadConfig => set(() => ({ fileUploadConfig })),
})

View File

@ -1,12 +1,8 @@
import {
getConnectedEdges,
getIncomers,
getOutgoers,
} from 'reactflow'
import { v4 as uuid4 } from 'uuid'
import {
groupBy,
isEqual,
uniqBy,
} from 'lodash-es'
import type {
@ -168,158 +164,6 @@ export const changeNodesAndEdgesId = (nodes: Node[], edges: Edge[]) => {
return [newNodes, newEdges] as [Node[], Edge[]]
}
type ParallelInfoItem = {
parallelNodeId: string
depth: number
isBranch?: boolean
}
type NodeParallelInfo = {
parallelNodeId: string
edgeHandleId: string
depth: number
}
type NodeHandle = {
node: Node
handle: string
}
type NodeStreamInfo = {
upstreamNodes: Set<string>
downstreamEdges: Set<string>
}
export const getParallelInfo = (startNode: Node, nodes: Node[], edges: Edge[]) => {
if (!startNode)
throw new Error('Start node not found')
const parallelList = [] as ParallelInfoItem[]
const nextNodeHandles = [{ node: startNode, handle: 'source' }]
let hasAbnormalEdges = false
const traverse = (firstNodeHandle: NodeHandle) => {
const nodeEdgesSet = {} as Record<string, Set<string>>
const totalEdgesSet = new Set<string>()
const nextHandles = [firstNodeHandle]
const streamInfo = {} as Record<string, NodeStreamInfo>
const parallelListItem = {
parallelNodeId: '',
depth: 0,
} as ParallelInfoItem
const nodeParallelInfoMap = {} as Record<string, NodeParallelInfo>
nodeParallelInfoMap[firstNodeHandle.node.id] = {
parallelNodeId: '',
edgeHandleId: '',
depth: 0,
}
while (nextHandles.length) {
const currentNodeHandle = nextHandles.shift()!
const { node: currentNode, handle: currentHandle = 'source' } = currentNodeHandle
const currentNodeHandleKey = currentNode.id
const connectedEdges = edges.filter(edge => edge.source === currentNode.id && edge.sourceHandle === currentHandle)
const connectedEdgesLength = connectedEdges.length
const outgoers = nodes.filter(node => connectedEdges.some(edge => edge.target === node.id))
const incomers = getIncomers(currentNode, nodes, edges)
if (!streamInfo[currentNodeHandleKey]) {
streamInfo[currentNodeHandleKey] = {
upstreamNodes: new Set<string>(),
downstreamEdges: new Set<string>(),
}
}
if (nodeEdgesSet[currentNodeHandleKey]?.size > 0 && incomers.length > 1) {
const newSet = new Set<string>()
for (const item of totalEdgesSet) {
if (!streamInfo[currentNodeHandleKey].downstreamEdges.has(item))
newSet.add(item)
}
if (isEqual(nodeEdgesSet[currentNodeHandleKey], newSet)) {
parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
nextNodeHandles.push({ node: currentNode, handle: currentHandle })
break
}
}
if (nodeParallelInfoMap[currentNode.id].depth > parallelListItem.depth)
parallelListItem.depth = nodeParallelInfoMap[currentNode.id].depth
outgoers.forEach((outgoer) => {
const outgoerConnectedEdges = getConnectedEdges([outgoer], edges).filter(edge => edge.source === outgoer.id)
const sourceEdgesGroup = groupBy(outgoerConnectedEdges, 'sourceHandle')
const incomers = getIncomers(outgoer, nodes, edges)
if (outgoers.length > 1 && incomers.length > 1)
hasAbnormalEdges = true
Object.keys(sourceEdgesGroup).forEach((sourceHandle) => {
nextHandles.push({ node: outgoer, handle: sourceHandle })
})
if (!outgoerConnectedEdges.length)
nextHandles.push({ node: outgoer, handle: 'source' })
const outgoerKey = outgoer.id
if (!nodeEdgesSet[outgoerKey])
nodeEdgesSet[outgoerKey] = new Set<string>()
if (nodeEdgesSet[currentNodeHandleKey]) {
for (const item of nodeEdgesSet[currentNodeHandleKey])
nodeEdgesSet[outgoerKey].add(item)
}
if (!streamInfo[outgoerKey]) {
streamInfo[outgoerKey] = {
upstreamNodes: new Set<string>(),
downstreamEdges: new Set<string>(),
}
}
if (!nodeParallelInfoMap[outgoer.id]) {
nodeParallelInfoMap[outgoer.id] = {
...nodeParallelInfoMap[currentNode.id],
}
}
if (connectedEdgesLength > 1) {
const edge = connectedEdges.find(edge => edge.target === outgoer.id)!
nodeEdgesSet[outgoerKey].add(edge.id)
totalEdgesSet.add(edge.id)
streamInfo[currentNodeHandleKey].downstreamEdges.add(edge.id)
streamInfo[outgoerKey].upstreamNodes.add(currentNodeHandleKey)
for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
streamInfo[item].downstreamEdges.add(edge.id)
if (!parallelListItem.parallelNodeId)
parallelListItem.parallelNodeId = currentNode.id
const prevDepth = nodeParallelInfoMap[currentNode.id].depth + 1
const currentDepth = nodeParallelInfoMap[outgoer.id].depth
nodeParallelInfoMap[outgoer.id].depth = Math.max(prevDepth, currentDepth)
}
else {
for (const item of streamInfo[currentNodeHandleKey].upstreamNodes)
streamInfo[outgoerKey].upstreamNodes.add(item)
nodeParallelInfoMap[outgoer.id].depth = nodeParallelInfoMap[currentNode.id].depth
}
})
}
parallelList.push(parallelListItem)
}
while (nextNodeHandles.length) {
const nodeHandle = nextNodeHandles.shift()!
traverse(nodeHandle)
}
return {
parallelList,
hasAbnormalEdges,
}
}
export const hasErrorHandleNode = (nodeType?: BlockEnum) => {
return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code
}

View File

@ -4,12 +4,12 @@ const translation = {
title: '空白の知識パイプライン',
description: 'データ処理と構造を完全に制御できるカスタムパイプラインをゼロから作成します。',
},
backToKnowledge: '知識に戻る',
backToKnowledge: 'ナレッジベースに戻る',
caution: '注意',
importDSL: 'DSLファイルからインポートする',
errorTip: 'ナレッジベースの作成に失敗しました',
createKnowledge: '知識を創造する',
successTip: '知識ベースが正常に作成されました',
createKnowledge: 'ナレッジベースを作成する',
successTip: 'ナレッジベースが正常に作成されました',
},
templates: {
customized: 'カスタマイズされた',
@ -21,10 +21,10 @@ const translation = {
preview: 'プレビュー',
dataSource: 'データソース',
editInfo: '情報を編集する',
exportPipeline: '輸出パイプライン',
exportPipeline: 'パイプラインをエクスポートする',
saveAndProcess: '保存して処理する',
backToDataSource: 'データソースに戻る',
useTemplate: 'この知識パイプラインを使用してください',
useTemplate: 'このナレッジパイプラインを使用してください',
process: 'プロセス',
},
deletePipeline: {
@ -37,7 +37,7 @@ const translation = {
tip: '<CustomLink>ドキュメントに移動</CustomLink>して、ドキュメントを追加または管理してください。',
},
error: {
message: '知識パイプラインの公開に失敗しました',
message: 'ナレッジパイプラインの公開に失敗しました',
},
},
publishTemplate: {
@ -147,19 +147,19 @@ const translation = {
content: 'この操作は永久的です。以前の方法に戻すことはできません。変換することを確認してください。',
},
warning: 'この操作は元に戻せません。',
title: '知識パイプラインに変換する',
title: 'ナレッジパイプラインに変換する',
successMessage: 'データセットをパイプラインに正常に変換しました',
errorMessage: 'データセットをパイプラインに変換できませんでした',
descriptionChunk1: '既存の知識ベースを文書処理のためにナレッジパイプラインを使用するように変換できます。',
descriptionChunk1: '既存のナレッジベースを文書処理のためにナレッジパイプラインを使用するように変換できます。',
descriptionChunk2: '— よりオープンで柔軟なアプローチを採用し、私たちのマーケットプレイスからのプラグインへのアクセスを提供します。これにより、すべての将来のドキュメントに新しい処理方法が適用されることになります。',
},
knowledgeNameAndIcon: '知識の名前とアイコン',
knowledgeNameAndIcon: 'ナレッジベースの名前とアイコン',
inputField: '入力フィールド',
pipelineNameAndIcon: 'パイプライン名とアイコン',
knowledgePermissions: '権限',
knowledgeNameAndIconPlaceholder: 'ナレッジベースの名前を入力してください',
editPipelineInfo: 'パイプライン情報を編集する',
knowledgeDescription: '知識の説明',
knowledgeDescription: 'ナレッジベースの説明',
knowledgeDescriptionPlaceholder: 'このナレッジベースに何が含まれているかを説明してください。詳細な説明は、AIがデータセットの内容により正確にアクセスできるようにします。空の場合、Difyはデフォルトのヒット戦略を使用します。オプション',
}

View File

@ -91,6 +91,7 @@ const remoteImageURLs = [hasSetWebPrefix ? new URL(`${process.env.NEXT_PUBLIC_WE
/** @type {import('next').NextConfig} */
const nextConfig = {
basePath: process.env.NEXT_PUBLIC_BASE_PATH || '',
transpilePackages: ['echarts', 'zrender'],
turbopack: {
rules: codeInspectorPlugin({
bundler: 'turbopack'

View File

@ -24,12 +24,9 @@
"build:docker": "next build && node scripts/optimize-standalone.js",
"start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js",
"lint": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache",
"lint-only-show-error": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet",
"fix": "eslint --concurrency=auto --fix .",
"eslint": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache",
"eslint-fix": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix",
"eslint-fix-only-show-error": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix --quiet",
"eslint-complexity": "eslint --concurrency=auto --rule 'complexity: [error, {max: 15}]' --quiet",
"lint:fix": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix",
"lint:quiet": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet",
"lint:complexity": "eslint --concurrency=auto --cache --cache-location node_modules/.cache/eslint/.eslint-cache --rule 'complexity: [error, {max: 15}]' --quiet",
"prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky",
"gen-icons": "node ./app/components/base/icons/script.mjs",
"uglify-embed": "node ./bin/uglify-embed",