This commit is contained in:
Joel 2026-04-20 13:41:51 +08:00
commit 3bb3670cb5
1312 changed files with 10841 additions and 7871 deletions

View File

@ -20,11 +20,11 @@
```typescript
// ❌ WRONG: Don't mock base components
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => <button>{children}</button>)
vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import { Button } from '@/app/components/base/ui/button'
import { Button } from '@langgenius/dify-ui/button'
// They will render normally in tests
```

View File

@ -76,13 +76,11 @@ jobs:
diff += '\\n\\n... (truncated) ...';
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body,
});
if (diff.trim()) {
await github.rest.issues.createComment({
issue_number: prNumber,
owner: context.repo.owner,
repo: context.repo.repo,
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
});
}

View File

@ -89,3 +89,37 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
dify-ui-test:
name: dify-ui Tests
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./packages/dify-ui
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Install Chromium for Browser Mode
run: vp exec playwright install --with-deps chromium
- name: Run dify-ui tests
run: vp test run --coverage --silent=passed-only
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: packages/dify-ui/coverage
flags: dify-ui
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -0,0 +1 @@
CURRENT_APP_DSL_VERSION = "0.6.0"

View File

@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
return value
try:

View File

@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.exposed_type().value,
"value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,

View File

@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile

View File

@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:

View File

@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)

View File

@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile

View File

@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)

View File

@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
execution.exceptions_count = runtime_state.exceptions_count
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,

View File

@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.CUSTOM,
file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),

View File

@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials

View File

@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded

View File

@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
return HttpResponse(
status_code=response.status_code,
headers=dict(response.headers),
content=response.content,
url=str(response.url) if response.url else None,
reason_phrase=response.reason_phrase,
fallback_text=response.text,
)
class GraphonSSRFProxy:
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
ssrf_proxy = SSRFProxy()
graphon_ssrf_proxy = GraphonSSRFProxy()

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class ModelInstance:
@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke

View File

@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_Hans
provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_US
else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):

View File

@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding

View File

@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
import inspect
import logging
import mimetypes
import os
@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
_closed: bool
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def close(self) -> None:
"""Best-effort cleanup for downloaded temporary files."""
if getattr(self, "_closed", False):
return
self._closed = True
temp_file = getattr(self, "temp_file", None)
if temp_file is None:
return
try:
close_result = temp_file.close()
if inspect.isawaitable(close_result):
close_awaitable = getattr(close_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
def __del__(self):
if hasattr(self, "temp_file"):
self.temp_file.close()
self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""

View File

@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
id=upload_file.id,
file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(

View File

@ -9,7 +9,7 @@ from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

View File

@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,

View File

@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
type=get_file_type_by_mime_type(tool_file.mimetype),
file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),

View File

@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.value
if not isinstance(variable_selector, list) or not all(
isinstance(selector_part, str) for selector_part in variable_selector
):
raise ToolParameterError("Variable tool input must be a variable selector")
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value

View File

@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke

View File

@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
transfer_method_value = file_dict.get("transfer_method")
if not isinstance(transfer_method_value, str):
raise ValueError("Workflow file mapping is missing a valid transfer_method")
transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id

View File

@ -1,8 +1,8 @@
"""Workflow-layer adapters for legacy human-input payload keys.
"""Workflow-to-Graphon adapters for persisted node payloads.
Stored workflow graphs and editor payloads may still use Dify-specific human
input recipient keys. Normalize them here before handing configs to
`graphon` so graph-owned models only see graph-neutral field names.
Stored workflow graphs and editor payloads still contain a small set of
Dify-owned field spellings and value shapes. Adapt them here before handing the
payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
normalized = normalize_human_input_node_data_for_graph(node_data)
normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
return normalized
return normalize_human_input_node_data_for_graph(normalized)
node_type = normalized.get("type")
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
return adapt_human_input_node_data_for_graph(normalized)
if node_type == BuiltinNodeTypes.TOOL:
return _adapt_tool_node_data_for_graph(normalized)
return normalized
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
normalized["data"] = normalize_node_data_for_graph(data_mapping)
normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(node_data)
raw_tool_configurations = normalized.get("tool_configurations")
if not isinstance(raw_tool_configurations, Mapping):
return normalized
existing_tool_parameters = normalized.get("tool_parameters")
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
normalized_tool_configurations: dict[str, Any] = {}
found_legacy_tool_inputs = False
for name, value in raw_tool_configurations.items():
if not isinstance(value, Mapping):
normalized_tool_configurations[name] = value
continue
input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
normalized_tool_configurations[name] = value
continue
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, dict(value))
flattened_value = _flatten_legacy_tool_configuration_value(
input_type=input_type,
input_value=input_value,
)
if flattened_value is not None:
normalized_tool_configurations[name] = flattened_value
if not found_legacy_tool_inputs:
return normalized
normalized["tool_parameters"] = normalized_tool_parameters
normalized["tool_configurations"] = normalized_tool_configurations
return normalized
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
return input_value
if (
input_type == "variable"
and isinstance(input_value, list)
and all(isinstance(item, str) for item in input_value)
):
return "{{#" + ".".join(input_value) + "#}}"
return None
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
"adapt_human_input_node_data_for_graph",
"adapt_node_config_for_graph",
"adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
"normalize_human_input_node_data_for_graph",
"normalize_node_config_for_graph",
"normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]

View File

@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
"""Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
# Graph configs are initially validated against permissive shared node data.
# Re-validate using the resolved node class so workflow-local node schemas
# stay explicit and constructors receive the concrete typed payload.
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
return node_class.validate_node_data(node_data)
validate_node_data = getattr(node_class, "validate_node_data", None)
if callable(validate_node_data):
return cast("BaseNodeData", validate_node_data(node_data))
return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .human_input_compat import (
from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: Mapping[str, Any],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Mapping[str, Any],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,

View File

@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: AgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: DatasourceNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeIndexNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
super().__init__(id, config, graph_init_params, graph_runtime_state)
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()

View File

@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
if value is None or isinstance(value, (str, float)):
return value
if isinstance(value, int) and not isinstance(value, bool):
return value
return str(value)
def _normalize_metadata_filter_sequence_item(value: object) -> str:
return value if isinstance(value, str) else str(value)
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: NodeConfigDict,
node_id: str,
config: KnowledgeRetrievalNodeData,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
) -> None:
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
resolved_value = segment_group.value[0].to_object()
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
resolved_values = []
for v in value: # type: ignore
resolved_values: list[str] = []
for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
resolved_values.append(segment_group.value[0].to_object())
resolved_values.append(
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
)
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values

View File

@ -148,11 +148,11 @@ def _build_from_local_file(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
@ -196,11 +196,11 @@ def _build_from_remote_url(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
@ -222,9 +222,9 @@ def _build_from_remote_url(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=filename,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@ -263,9 +263,9 @@ def _build_from_tool_file(
)
return File(
id=mapping.get("id"),
file_id=mapping.get("id"),
filename=tool_file.name,
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
@ -306,9 +306,9 @@ def _build_from_datasource_file(
)
return File(
id=mapping.get("datasource_file_id"),
file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
type=file_type,
file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),

View File

@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
return value_type.exposed_type().value
return str(value_type.exposed_type())

View File

@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
return str(exposed_type().value)
return str(exposed_type())
if isinstance(value, str):
try:
return str(SegmentType(value).exposed_type().value)
return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:

View File

@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.exposed_type().value,
"value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):

View File

@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
)
)
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
if data_source_binding:
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
session.add(new_data_source_binding)
session.commit()
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
with session_factory.create_session() as session:
data_source_binding = session.scalar(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False,
)
)
)
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []

View File

@ -3,6 +3,7 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -36,24 +37,27 @@ class WorkflowComment(Base):
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
position_y: Mapped[float] = mapped_column(sa.Float)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
resolved: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
@ -143,23 +147,22 @@ class WorkflowCommentReply(Base):
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
comment: Mapped["WorkflowComment"] = relationship(
"WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
@ -187,24 +190,23 @@ class WorkflowCommentMention(Base):
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(
StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
comment: Mapped["WorkflowComment"] = relationship(
"WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]
] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):

View File

@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
from core.workflow.human_input_compat import DeliveryMethodType
from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string

View File

@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
from graphon.file import File, FileTransferMethod
from graphon.file import File, FileTransferMethod, FileType
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
    """Build a graph `File` directly from serialized metadata.

    No database lookup is performed: every field is read defensively from the
    mapping, with legacy key fallbacks (``id`` -> ``file_id``, ``type`` ->
    ``file_type``, ``url`` -> ``remote_url``) so historical payloads still
    deserialize. Fields that fail their type check are dropped to a safe
    default rather than raised on; only a missing/invalid file type or
    transfer_method is fatal.
    """

    def _coerce_file_type(value: Any) -> FileType:
        # Accept an already-parsed enum, a string name, nothing else.
        if isinstance(value, FileType):
            return value
        if isinstance(value, str):
            return FileType.value_of(value)
        raise ValueError("file type is required in file mapping")

    # Work on a copy so the caller's mapping is never mutated.
    mapping = dict(file_mapping)

    # transfer_method is mandatory: enum instance or string name only.
    transfer_method_value = mapping.get("transfer_method")
    if isinstance(transfer_method_value, FileTransferMethod):
        transfer_method = transfer_method_value
    elif isinstance(transfer_method_value, str):
        transfer_method = FileTransferMethod.value_of(transfer_method_value)
    else:
        raise ValueError("transfer_method is required in file mapping")

    # Legacy payloads stored the file id under "id" instead of "file_id".
    file_id = mapping.get("file_id")
    if not isinstance(file_id, str) or not file_id:
        legacy_id = mapping.get("id")
        file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None

    # Prefer the structured reference; fall back to a raw "related_id" key.
    # NOTE(review): resolve_file_record_id is not in this hunk's visible
    # imports — confirm it is imported elsewhere at the top of this module.
    related_id = resolve_file_record_id(mapping)
    if related_id is None:
        raw_related_id = mapping.get("related_id")
        related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None

    # Legacy payloads stored the remote URL under "url".
    remote_url = mapping.get("remote_url")
    if not isinstance(remote_url, str) or not remote_url:
        url = mapping.get("url")
        remote_url = url if isinstance(url, str) and url else None

    reference = mapping.get("reference")
    if not isinstance(reference, str) or not reference:
        reference = None

    filename = mapping.get("filename")
    if not isinstance(filename, str):
        filename = None

    extension = mapping.get("extension")
    if not isinstance(extension, str):
        extension = None

    mime_type = mapping.get("mime_type")
    if not isinstance(mime_type, str):
        mime_type = None

    # -1 signals "size unknown".
    # NOTE(review): bool is an int subclass, so a bool size slips through the
    # isinstance check — confirm upstream serializers never emit one.
    size = mapping.get("size", -1)
    if not isinstance(size, int):
        size = -1

    storage_key = mapping.get("storage_key")
    if not isinstance(storage_key, str):
        storage_key = None

    tenant_id = mapping.get("tenant_id")
    if not isinstance(tenant_id, str):
        tenant_id = None

    # Missing/invalid identity marker falls back to the canonical constant.
    dify_model_identity = mapping.get("dify_model_identity")
    if not isinstance(dify_model_identity, str):
        dify_model_identity = FILE_MODEL_IDENTITY

    tool_file_id = mapping.get("tool_file_id")
    if not isinstance(tool_file_id, str):
        tool_file_id = None

    upload_file_id = mapping.get("upload_file_id")
    if not isinstance(upload_file_id, str):
        upload_file_id = None

    datasource_file_id = mapping.get("datasource_file_id")
    if not isinstance(datasource_file_id, str):
        datasource_file_id = None

    return File(
        file_id=file_id,
        tenant_id=tenant_id,
        # "type" is the legacy key for "file_type".
        file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
        transfer_method=transfer_method,
        remote_url=remote_url,
        reference=reference,
        related_id=related_id,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
        storage_key=storage_key,
        dify_model_identity=dify_model_identity,
        # url mirrors remote_url for backward compatibility.
        url=remote_url,
        tool_file_id=tool_file_id,
        upload_file_id=upload_file_id,
        datasource_file_id=datasource_file_id,
    )
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
    """Recursively convert serialized file mappings inside *value* into `File`s.

    Lists and dicts are walked depth-first; any dict recognized by
    `maybe_file_object` is rebuilt through
    `build_file_from_mapping_without_lookup`, and every other value is
    returned untouched. `graphon` 0.2.2 no longer accepts legacy serialized
    file mappings via `model_validate_json()`, so Dify keeps this recovery
    path at the model boundary to keep historical JSON blobs readable without
    reintroducing global graph patches or test-local coercion.
    """
    if isinstance(value, list):
        return [rebuild_serialized_graph_files_without_lookup(entry) for entry in value]
    if not isinstance(value, dict):
        return value
    if maybe_file_object(value):
        return build_file_from_mapping_without_lookup(file_mapping=value)
    return {
        key: rebuild_serialized_graph_files_without_lookup(entry)
        for key, entry in value.items()
    }
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")
if not isinstance(remote_url, str) or not remote_url:
url = mapping.get("url")
if isinstance(url, str) and url:
mapping["remote_url"] = url
return File.model_validate(mapping)
return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,

View File

@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.human_input_compat import normalize_node_config_for_graph
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
from .utils.file_input_compat import build_file_from_stored_mapping
from .utils.file_input_compat import (
build_file_from_mapping_without_lookup,
build_file_from_stored_mapping,
)
logger = logging.getLogger(__name__)
@ -292,7 +295,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@ -1690,7 +1693,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
return File.model_validate(normalized_file)
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@ -1700,7 +1703,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
file_list.append(File.model_validate(normalized_file))
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)

View File

@ -1,7 +1,6 @@
"""
Tencent APM tracing implementation with separated concerns
"""
"""Tencent APM tracing with idempotent client cleanup."""
import inspect
import logging
from sqlalchemy import select
@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
explicitly in tests or implicitly during garbage collection, so shutdown
must be safe to call multiple times.
"""
trace_client: TencentTraceClient
_closed: bool
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
def __del__(self):
"""Ensure proper cleanup on garbage collection."""
def close(self) -> None:
"""Synchronously and idempotently shutdown the underlying trace client."""
if getattr(self, "_closed", False):
return
self._closed = True
trace_client = getattr(self, "trace_client", None)
if trace_client is None:
return
try:
if hasattr(self, "trace_client"):
self.trace_client.shutdown()
shutdown_result = trace_client.shutdown()
if inspect.isawaitable(shutdown_result):
close_awaitable = getattr(shutdown_result, "close", None)
if callable(close_awaitable):
close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
def __del__(self):
"""Ensure best-effort cleanup on garbage collection without retrying shutdown."""
self.close()

View File

@ -1,5 +1,7 @@
import gc
import logging
from unittest.mock import MagicMock, patch
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dify_trace_tencent.config import TencentConfig
@ -632,13 +634,38 @@ class TestTencentDataTrace:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
def test_del(self, tencent_data_trace):
def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
tencent_data_trace.__del__()
tencent_data_trace.close()
client.shutdown.assert_called_once()
def test_del_exception(self, tencent_data_trace):
def test_close_is_idempotent(self, tencent_data_trace):
client = tencent_data_trace.trace_client
tencent_data_trace.close()
tencent_data_trace.close()
client.shutdown.assert_called_once()
def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.__del__()
tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
shutdown = AsyncMock()
tencent_data_trace.trace_client.shutdown = shutdown
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
tencent_data_trace.close()
gc.collect()
shutdown.assert_called_once()
assert not [
warning
for warning in caught
if issubclass(warning.category, RuntimeWarning)
and "AsyncMockMixin._execute_mock_call" in str(warning.message)
]

View File

@ -44,7 +44,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.1.2",
"graphon~=0.2.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
]
@ -173,7 +173,7 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.60.0",
"pyrefly>=0.61.1",
"xinference-client>=2.4.0",
]

View File

@ -7,8 +7,8 @@ from sqlalchemy import select
import app
from configs import dify_config
from core.db.session_factory import session_factory
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin
@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = db.session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue
with session_factory.create_session() as session:
dataset_auto_disable_logs = session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != CloudPlan.SANDBOX:
knowledge_details = []
# check tenant
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
if not tenant:
continue
# check current owner
current_owner_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
)
if not current_owner_join:
continue
account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
if not account:
continue
dataset_auto_dataset_map = {} # type: ignore
dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
language_code="en-US",
to=account.email,
template_context={
"userName": account.email,
"knowledge_details": knowledge_details,
"url": url,
},
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
dataset_auto_disable_log.notified = True
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:

View File

@ -17,6 +17,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from constants.dsl_version import CURRENT_APP_DSL_VERSION
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from core.trigger.constants import (
@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
CURRENT_DSL_VERSION = "0.6.0"
CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION
class Import(BaseModel):

View File

@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account

View File

@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user

View File

@ -3,6 +3,7 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from constants.dsl_version import CURRENT_APP_DSL_VERSION
from enums.cloud_plan import CloudPlan
from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel):
class SystemFeatureModel(BaseModel):
app_dsl_version: str = ""
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
enable_marketplace: bool = False
@ -225,6 +227,7 @@ class FeatureService:
@classmethod
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
system_features.app_dsl_version = CURRENT_APP_DSL_VERSION
cls._fulfill_system_params_from_env(system_features)

View File

@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,

View File

@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
node_type_enum = NodeType(node_type)
node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config

View File

@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
json_str = dumps_with_segments(result.value, ensure_ascii=False)
json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)

View File

@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
id=draft_var.id,
variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@ -1075,7 +1075,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"

View File

@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
normalize_human_input_node_data_for_graph,
adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@ -112,7 +112,8 @@ class WorkflowService:
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
session_maker = sessionmaker(
bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
@ -216,7 +217,8 @@ class WorkflowService:
if not app_ids:
return set()
stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id)
stmt = select(App.id).where(App.id.in_(
app_ids), App.tenant_id == tenant_id)
return {str(app_id) for app_id in db.session.scalars(stmt).all()}
def get_all_published_workflow(
@ -300,7 +302,8 @@ class WorkflowService:
)
)
stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit)
stmt = stmt.order_by(Workflow.created_at.desc()).limit(
limit + 1).offset((page - 1) * limit)
workflows = session.scalars(stmt).all()
@ -332,7 +335,8 @@ class WorkflowService:
raise WorkflowHashNotEqualError()
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
self.validate_features_structure(
app_model=app_model, features=features)
# validate graph structure
self.validate_graph_structure(graph=graph)
@ -364,7 +368,8 @@ class WorkflowService:
db.session.commit()
# trigger app workflow events
app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow)
app_draft_workflow_was_synced.send(
app_model, synced_draft_workflow=workflow)
# return draft workflow
return workflow
@ -432,7 +437,8 @@ class WorkflowService:
raise ValueError("No draft workflow found.")
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
self.validate_features_structure(
app_model=app_model, features=features)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
@ -453,11 +459,13 @@ class WorkflowService:
Secret environment variables are copied server-side from the selected
published workflow so the normal draft sync flow stays stateless.
"""
source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
source_workflow = self.get_published_workflow_by_id(
app_model=app_model, workflow_id=workflow_id)
if not source_workflow:
raise WorkflowNotFoundError("Workflow not found.")
self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict)
self.validate_features_structure(
app_model=app_model, features=source_workflow.normalized_features_dict)
self.validate_graph_structure(graph=source_workflow.graph_dict)
draft_workflow = self.get_draft_workflow(app_model=app_model)
@ -474,7 +482,8 @@ class WorkflowService:
db.session.add(draft_workflow)
db.session.commit()
app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow)
app_draft_workflow_was_synced.send(
app_model, synced_draft_workflow=draft_workflow)
return draft_workflow
@ -518,7 +527,8 @@ class WorkflowService:
and is_trigger_node_type(node_type_str)
)
if trigger_node_count > 2:
raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2)
raise TriggerNodeLimitExceededError(
count=trigger_node_count, limit=2)
# create new workflow
workflow = Workflow.new(
@ -540,7 +550,8 @@ class WorkflowService:
session.add(workflow)
# trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
app_published_workflow_was_updated.send(
app_model, published_workflow=workflow)
# return new workflow
return workflow
@ -597,7 +608,8 @@ class WorkflowService:
session.add(workflow)
# trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
app_published_workflow_was_updated.send(
app_model, published_workflow=workflow)
return workflow
@ -615,7 +627,8 @@ class WorkflowService:
This endpoint only supports conversion between standard workflow and evaluation workflow.
"""
if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
raise ValueError("target_type must be either 'workflow' or 'evaluation'")
raise ValueError(
"target_type must be either 'workflow' or 'evaluation'")
if not app_model.workflow_id:
raise WorkflowNotFoundError("Published workflow not found")
@ -630,7 +643,8 @@ class WorkflowService:
raise WorkflowNotFoundError("Published workflow not found")
if workflow.version == Workflow.VERSION_DRAFT:
raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.")
raise IsDraftWorkflowError(
"Current effective workflow cannot be a draft version.")
if workflow.type == target_type:
return workflow
@ -642,7 +656,8 @@ class WorkflowService:
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
app_published_workflow_was_updated.send(
app_model, published_workflow=workflow)
return workflow
@ -660,7 +675,8 @@ class WorkflowService:
if not disallowed_nodes:
return
formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
formatted_nodes = ", ".join(
f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
raise ValueError(
"Evaluation workflow cannot contain trigger or human-input nodes. "
f"Found disallowed nodes: {formatted_nodes}"
@ -698,12 +714,14 @@ class WorkflowService:
)
else:
# Check default workspace credential for this provider
self._check_default_tool_credential(workflow.tenant_id, provider)
self._check_default_tool_credential(
workflow.tenant_id, provider)
elif node_type == "agent":
agent_params = node_data.get("agent_parameters", {})
model_config = agent_params.get("model", {}).get("value", {})
model_config = agent_params.get(
"model", {}).get("value", {})
if model_config.get("provider") and model_config.get("model"):
self._validate_llm_model_config(
workflow.tenant_id, model_config["provider"], model_config["model"]
@ -711,7 +729,8 @@ class WorkflowService:
# Validate load balancing credentials for agent model if load balancing is enabled
agent_model_node_data = {"model": model_config}
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)
self._validate_load_balancing_credentials(
workflow, agent_model_node_data, node_id)
# Validate agent tools
tools = agent_params.get("tools", {}).get("value", [])
@ -723,9 +742,11 @@ class WorkflowService:
if credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
check_credential_policy_compliance(
credential_id, provider, PluginCredentialType.TOOL)
else:
self._check_default_tool_credential(workflow.tenant_id, provider)
self._check_default_tool_credential(
workflow.tenant_id, provider)
elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
model_config = node_data.get("model", {})
@ -734,11 +755,14 @@ class WorkflowService:
if provider and model_name:
# Validate that the provider+model combination can fetch valid credentials
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
self._validate_llm_model_config(
workflow.tenant_id, provider, model_name)
# Validate load balancing credentials if load balancing is enabled
self._validate_load_balancing_credentials(workflow, node_data, node_id)
self._validate_load_balancing_credentials(
workflow, node_data, node_id)
else:
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
raise ValueError(
f"Node {node_id} ({node_type}): Missing provider or model configuration")
except Exception as e:
if isinstance(e, ValueError):
@ -780,8 +804,10 @@ class WorkflowService:
# If it fails, an exception will be raised
# Additionally, check the model status to ensure it's ACTIVE
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
provider_configurations = assembly.provider_manager.get_configurations(
tenant_id)
models = provider_configurations.get_models(
provider=provider, model_type=ModelType.LLM)
target_model = None
for model in models:
@ -792,7 +818,8 @@ class WorkflowService:
if target_model:
target_model.raise_for_status()
else:
raise ValueError(f"Model {model_name} not found for provider {provider}")
raise ValueError(
f"Model {model_name} not found for provider {provider}")
except Exception as e:
raise ValueError(
@ -840,7 +867,8 @@ class WorkflowService:
)
except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
raise ValueError(
f"Failed to validate default credential for tool provider {provider}: {str(e)}")
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None:
"""
@ -862,7 +890,8 @@ class WorkflowService:
# Check if this model has load balancing enabled
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
# Get all load balancing configurations for this model
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
load_balancing_configs = self._get_load_balancing_configs(
workflow.tenant_id, provider, model_name)
# Validate each load balancing configuration
try:
for config in load_balancing_configs:
@ -873,7 +902,8 @@ class WorkflowService:
config["credential_id"], provider, PluginCredentialType.MODEL
)
except Exception as e:
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
raise ValueError(
f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
"""
@ -888,8 +918,10 @@ class WorkflowService:
from graphon.model_runtime.entities.model_entities import ModelType
# Get provider configurations
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
provider_configurations = provider_manager.get_configurations(tenant_id)
provider_manager = create_plugin_provider_manager(
tenant_id=tenant_id)
provider_configurations = provider_manager.get_configurations(
tenant_id)
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
@ -930,7 +962,8 @@ class WorkflowService:
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
all_configs = cast(list[dict[str, Any]], configs) + \
cast(list[dict[str, Any]], custom_configs)
return [config for config in all_configs if config.get("credential_id")]
@ -975,7 +1008,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
node_type_enum = NodeType(node_type)
node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@ -994,7 +1027,8 @@ class WorkflowService:
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
)
default_config = node_class.get_default_config(filters=resolved_filters or None)
default_config = node_class.get_default_config(
filters=resolved_filters or None)
if not default_config:
return {}
@ -1017,7 +1051,8 @@ class WorkflowService:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id)
draft_var_srv.prefill_conversation_variable_default_values(
draft_workflow, user_id=account.id)
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
@ -1031,7 +1066,8 @@ class WorkflowService:
workflow=draft_workflow,
)
if node_type == BuiltinNodeTypes.START:
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
start_data = StartNodeData.model_validate(
node_data, from_attributes=True)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
@ -1066,7 +1102,8 @@ class WorkflowService:
user_id=account.id,
)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(
node_config)
if enclosing_node_type_and_id:
_, enclosing_node_id = enclosing_node_type_and_id
else:
@ -1101,12 +1138,15 @@ class WorkflowService:
)
repository.save(node_execution)
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(
node_execution.id)
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
raise ValueError(
f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with sessionmaker(db.engine).begin() as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
outputs = workflow_node_execution.load_full_outputs(
session, storage)
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
@ -1118,7 +1158,8 @@ class WorkflowService:
node_execution_id=node_execution.id,
user=account,
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
draft_var_saver.save(
process_data=node_execution.process_data, outputs=outputs)
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
@ -1245,7 +1286,8 @@ class WorkflowService:
rendered_content, outputs, node_data.outputs_field_names()
)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(
node_config)
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
with sessionmaker(bind=db.engine).begin() as session:
draft_var_saver = DraftVariableSaver(
@ -1280,7 +1322,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
normalize_human_input_node_data_for_graph(node_config["data"]),
adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@ -1332,7 +1374,8 @@ class WorkflowService:
try:
test_service.send_test(context=context, method=delivery_method)
except DeliveryTestUnsupportedError as exc:
raise ValueError("Delivery method does not support test send.") from exc
raise ValueError(
"Delivery method does not support test send.") from exc
except DeliveryTestError as exc:
raise ValueError(str(exc)) from exc
@ -1357,7 +1400,8 @@ class WorkflowService:
rendered_content: str,
resolved_default_values: Mapping[str, Any],
) -> tuple[str, list[DeliveryTestEmailRecipient]]:
repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id)
repo = HumanInputFormRepositoryImpl(
tenant_id=app_model.tenant_id, app_id=app_model.id)
params = FormCreateParams(
workflow_execution_id=None,
node_id=node_id,
@ -1377,7 +1421,8 @@ class WorkflowService:
with Session(bind=db.engine) as session:
recipients = session.scalars(
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id)
select(HumanInputFormRecipient).where(
HumanInputFormRecipient.form_id == form_id)
).all()
recipients_data: list[DeliveryTestEmailRecipient] = []
for recipient in recipients:
@ -1388,11 +1433,13 @@ class WorkflowService:
try:
payload = json.loads(recipient.recipient_payload)
except (json.JSONDecodeError, ValueError):
logger.exception("Failed to parse human input recipient payload for delivery test.")
logger.exception(
"Failed to parse human input recipient payload for delivery test.")
continue
email = payload.get("email")
if email:
recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token))
recipients_data.append(DeliveryTestEmailRecipient(
email=email, form_token=recipient.access_token))
return recipients_data
def _build_human_input_node(
@ -1421,9 +1468,11 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
node_data = HumanInputNode.validate_node_data(
adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
id=node_config["id"],
config=node_config,
node_id=node_config["id"],
config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@ -1441,7 +1490,8 @@ class WorkflowService:
) -> VariablePool:
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id)
draft_var_srv.prefill_conversation_variable_default_values(
workflow, user_id=user_id)
variable_pool = VariablePool()
add_variables_to_pool(
@ -1519,7 +1569,8 @@ class WorkflowService:
Returns:
WorkflowNodeExecution: The execution result
"""
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
node, node_run_result, run_succeeded, error = self._execute_node_safely(
invoke_node_fn)
# Create base node execution
node_execution = WorkflowNodeExecution(
@ -1535,7 +1586,8 @@ class WorkflowService:
)
# Populate execution result data
self._populate_execution_result(node_execution, node_run_result, run_succeeded, error)
self._populate_execution_result(
node_execution, node_run_result, run_succeeded, error)
return node_execution
@ -1564,7 +1616,8 @@ class WorkflowService:
# Apply error strategy if node failed
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy:
node_run_result = self._apply_error_strategy(node, node_run_result)
node_run_result = self._apply_error_strategy(
node, node_run_result)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
@ -1595,7 +1648,8 @@ class WorkflowService:
status=WorkflowNodeExecutionStatus.EXCEPTION,
error=node_run_result.error,
inputs=node_run_result.inputs,
metadata={WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy},
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy},
outputs=error_outputs,
)
@ -1609,10 +1663,12 @@ class WorkflowService:
"""Populate node execution with result data."""
if run_succeeded and node_run_result:
node_execution.inputs = (
WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
WorkflowEntry.handle_special_values(
node_run_result.inputs) if node_run_result.inputs else None
)
node_execution.process_data = (
WorkflowEntry.handle_special_values(node_run_result.process_data)
WorkflowEntry.handle_special_values(
node_run_result.process_data)
if node_run_result.process_data
else None
)
@ -1641,7 +1697,8 @@ class WorkflowService:
workflow_converter = WorkflowConverter()
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
raise ValueError(
f"Current App mode: {app_model.mode} is not supported convert to workflow.")
# convert to workflow
new_app: App = workflow_converter.convert_to_workflow(
@ -1678,7 +1735,8 @@ class WorkflowService:
# start node and trigger node cannot coexist
if BuiltinNodeTypes.START in node_types:
if any(is_trigger_node_type(nt) for nt in node_types):
raise ValueError("Start node and trigger nodes cannot coexist in the same workflow")
raise ValueError(
"Start node and trigger nodes cannot coexist in the same workflow")
for node in node_configs:
node_data = node.get("data", {})
@ -1713,7 +1771,8 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
HumanInputNodeData.model_validate(
adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
@ -1730,7 +1789,8 @@ class WorkflowService:
:param data: Dictionary containing fields to update
:return: Updated workflow or None if not found
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
stmt = select(Workflow).where(Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
@ -1759,7 +1819,8 @@ class WorkflowService:
:raises: WorkflowInUseError if workflow is in use
:raises: DraftWorkflowDeletionError if workflow is a draft version
"""
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
stmt = select(Workflow).where(Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id)
workflow = session.scalar(stmt)
if not workflow:
@ -1767,14 +1828,16 @@ class WorkflowService:
# Check if workflow is a draft version
if workflow.version == Workflow.VERSION_DRAFT:
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
raise DraftWorkflowDeletionError(
"Cannot delete draft workflow versions")
# Check if this workflow is currently referenced by an app
app_stmt = select(App).where(App.workflow_id == workflow_id)
app = session.scalar(app_stmt)
if app:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
raise WorkflowInUseError(
f"Cannot delete workflow that is currently in use by app '{app.id}'")
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
# Check if there's a tool provider using this specific workflow version
@ -1788,7 +1851,8 @@ class WorkflowService:
if tool_provider:
# Cannot delete a workflow that's published as a tool
raise WorkflowInUseError("Cannot delete workflow that is published as a tool")
raise WorkflowInUseError(
"Cannot delete workflow that is published as a tool")
session.delete(workflow)
return True
@ -1837,11 +1901,13 @@ def _setup_variable_pool(
build_bootstrap_variables(
system_variables=system_variable,
environment_variables=workflow.environment_variables,
conversation_variables=cast(list[Variable], conversation_variables),
conversation_variables=cast(
list[Variable], conversation_variables),
),
)
if is_start_node_type(node_type):
add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs)
add_node_inputs_to_pool(
variable_pool, node_id=node_id, inputs=user_inputs)
return variable_pool
@ -1857,7 +1923,8 @@ def _rebuild_file_for_user_inputs_in_start_node(
if variable.variable not in user_inputs:
continue
value = user_inputs[variable.variable]
file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
file = _rebuild_single_file(
tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
inputs_copy[variable.variable] = file
return inputs_copy
@ -1865,15 +1932,18 @@ def _rebuild_file_for_user_inputs_in_start_node(
def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
if variable_entity_type == VariableEntityType.FILE:
if not isinstance(value, dict):
raise ValueError(f"expected dict for file object, got {type(value)}")
raise ValueError(
f"expected dict for file object, got {type(value)}")
return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller)
elif variable_entity_type == VariableEntityType.FILE_LIST:
if not isinstance(value, list):
raise ValueError(f"expected list for file list object, got {type(value)}")
raise ValueError(
f"expected list for file list object, got {type(value)}")
if len(value) == 0:
return []
if not isinstance(value[0], dict):
raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
raise ValueError(
f"expected dict for first element in the file list, got {type(value)}")
return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller)
else:
raise Exception("unreachable")

View File

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool

View File

@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
id="n",
config={
"id": "n",
"data": {
"type": "datasource",
"version": "1",
"title": "Datasource",
"provider_type": "plugin",
"provider_name": "p",
"plugin_id": "plug",
"datasource_name": "ds",
},
},
node_id="n",
config=DatasourceNodeData(
type="datasource",
version="1",
title="Datasource",
provider_type="plugin",
provider_name="p",
plugin_id="plug",
datasource_name="ds",
),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)

View File

@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
id=str(uuid.uuid4()),
config=code_config,
node_id=str(uuid.uuid4()),
config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,

View File

@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
id=str(uuid.uuid4()),
config=graph_config["nodes"][1],
node_id=str(uuid.uuid4()),
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,

View File

@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),

View File

@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),

View File

@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
id=str(uuid.uuid4()),
config=config,
node_id=str(uuid.uuid4()),
config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,

View File

@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,

View File

@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
node_id="start",
config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
id="human",
config={"id": "human", "data": human_data.model_dump()},
node_id="human",
config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
node_id="end",
config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)

View File

@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
id=str(uuid4()), # Generate new UUID for File.id
file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,

View File

@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
from core.workflow.human_input_compat import DeliveryMethodType
from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)

View File

@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
id="test_file_id",
type=FileType.IMAGE,
file_id="test_file_id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
id="test_file_id",
type=FileType.IMAGE,
file_id="test_file_id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",

View File

@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
from fields._value_type_serializer import serialize_value_type
from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
def test_variable_response_normalizes_string_value_type_alias(self):
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SegmentType.INTEGER.value,
}
)
assert response.value_type == "number"
def test_variable_response_normalizes_callable_exposed_type(self):
response = ConversationVariableResponse.model_validate(
{
"id": "550e8400-e29b-41d4-a716-446655440000",
"name": "foo",
"value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
}
)
assert response.value_type == "string"
def test_serialize_value_type_supports_segments_and_mappings(self):
assert serialize_value_type(StringSegment(value="hello")) == "string"
assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{

View File

@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
id=file_id,
type=FileType.DOCUMENT,
file_id=file_id,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",

View File

@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
def init_node_data(self, data):
self._node_data = _StubToolNodeData.model_validate(data)
def __init__(
self,
node_id: str,
config: _StubToolNodeData,
*,
graph_init_params,
graph_runtime_state,
**_kwargs: Any,
) -> None:
super().__init__(
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
def _get_error_strategy(self):
return self._node_data.error_strategy
@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
original_create_node = DifyNodeFactory.create_node
original_resolve_node_class = node_factory_module.resolve_workflow_node_class
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_data = typed_node_config["data"]
if node_data.type == BuiltinNodeTypes.TOOL:
return _StubToolNode(
id=str(typed_node_config["id"]),
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
return original_create_node(self, typed_node_config)
def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
if node_type == BuiltinNodeTypes.TOOL:
return _StubToolNode
return original_resolve_node_class(node_type=node_type, node_version=node_version)
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:

View File

@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
id="file-id",
type=FileType.IMAGE,
file_id="file-id",
file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)

View File

@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
self.id = id
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state

View File

@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",

View File

@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
assert simple_provider.label.en_US == "OpenAI"
assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]

View File

@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
id="test-file",
file_id="test-file",
tenant_id="test-tenant-id",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
def test_file_model_validate_accepts_legacy_tenant_id():
data = {
"id": "test-file",
"tenant_id": "test-tenant-id",
"type": "image",
"transfer_method": "tool_file",
"related_id": "test-related-id",
"filename": "image.png",
"extension": ".png",
"mime_type": "image/png",
"size": 67,
"storage_key": "test-storage-key",
"url": "https://example.com/image.png",
# Extra legacy fields
"tool_file_id": "tool-file-123",
"upload_file_id": "upload-file-456",
"datasource_file_id": "datasource-file-789",
}
def test_file_constructor_accepts_legacy_tenant_id():
file = File(
file_id="test-file",
tenant_id="test-tenant-id",
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
tool_file_id="tool-file-123",
filename="image.png",
extension=".png",
mime_type="image/png",
size=67,
storage_key="test-storage-key",
url="https://example.com/image.png",
)
file = File.model_validate(data)
assert file.related_id == "test-related-id"
assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()

View File

@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
SSRFProxy,
_get_user_provided_host_header,
_to_graphon_http_response,
graphon_ssrf_proxy,
make_request,
max_retries_exceeded_error,
request_error,
)
@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
response = httpx.Response(
201,
headers={"X-Test": "1"},
content=b"payload",
request=httpx.Request("GET", "https://example.com/resource"),
)
wrapped = _to_graphon_http_response(response)
assert wrapped.status_code == 201
assert wrapped.headers == {"x-test": "1", "content-length": "7"}
assert wrapped.content == b"payload"
assert wrapped.url == "https://example.com/resource"
assert wrapped.reason_phrase == "Created"
assert wrapped.text == "payload"
def test_ssrf_proxy_exposes_expected_error_types() -> None:
proxy = SSRFProxy()
assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
assert proxy.request_error is request_error
assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
assert graphon_ssrf_proxy.request_error is request_error
@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
response = httpx.Response(
200,
headers={"X-Test": "1"},
content=b"ok",
request=httpx.Request("GET", "https://example.com/resource"),
)
with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
wrapped = getattr(graphon_ssrf_proxy, method_name)(
"https://example.com/resource",
max_retries=3,
headers={"X-Test": "1"},
)
mock_method.assert_called_once_with(
url="https://example.com/resource",
max_retries=3,
headers={"X-Test": "1"},
)
assert wrapped.status_code == 200
assert wrapped.url == "https://example.com/resource"
assert wrapped.content == b"ok"

View File

@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory

View File

@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
assert providers[0].label.en_US == "OpenAI"
assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:

View File

@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
type=FileType.DOCUMENT,
file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",

View File

@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",

View File

@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation

View File

@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform

View File

@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
import gc
import io
import os
import tempfile
import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
def test_del_closes_temp_file():
def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
WordExtractor.__del__(extractor)
extractor.close()
extractor.temp_file.close.assert_called_once()
def test_close_is_idempotent():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
extractor.close()
extractor.close()
extractor.temp_file.close.assert_called_once()
def test_close_handles_async_close_mock():
extractor = object.__new__(WordExtractor)
extractor._closed = False
extractor.temp_file = MagicMock()
extractor.temp_file.close = AsyncMock()
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
extractor.close()
gc.collect()
extractor.temp_file.close.assert_called_once()
assert not [
warning
for warning in caught
if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
]
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):

View File

@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
from core.workflow.human_input_compat import (
from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,

View File

@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
id="file1",
file_id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",

View File

@ -1,8 +1,9 @@
import dataclasses
from typing import Annotated
import orjson
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
type SegmentUnion = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
),
Discriminator(get_segment_discriminator),
]
def _build_variable_pool(
@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
"""Test serialization/deserialization of all segment types"""
"""Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
loaded = _Segments.model_validate_json(json_str)
loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
"""Test serialization/deserialization of all variable types"""
"""Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
loaded = _Variables.model_validate_json(json_str)
loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)

View File

@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
type=file_type,
file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,

View File

@ -1,12 +1,13 @@
"""
Mock node factory for testing workflows with third-party service dependencies.
"""Mock node factory for third-party-service workflow tests.
This module provides a MockNodeFactory that automatically detects and mocks nodes
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
The factory follows the same config adaptation path as production
`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
node_id = typed_node_config["id"]
# Create mock node instance
mock_class = self._mock_node_types[node_type]
resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
id=node_id,
config=typed_node_config,
node_id=node_id,
config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
return super().create_node(typed_node_config)
return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""

View File

@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
id: str,
config: Mapping[str, Any],
node_id: str,
config: Any,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
id=id,
node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

Some files were not shown because too many files have changed in this diff Show More