diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 82d9c21cbb..5884830569 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -20,11 +20,11 @@ ```typescript // ❌ WRONG: Don't mock base components vi.mock('@/app/components/base/loading', () => () =>
Loading
) -vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => ) +vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => ) // ✅ CORRECT: Import and use real base components import Loading from '@/app/components/base/loading' -import { Button } from '@/app/components/base/ui/button' +import { Button } from '@langgenius/dify-ui/button' // They will render normally in tests ``` diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index eefb1ebbb9..c55b013dbe 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
', + }); + } diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f3ab4c62c7..2a5cf19645 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 0000000000..b0fbe0075c --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index cead33d14f..9c8b095b9f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): return value try: diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f6319573e0..e32ba5f66c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable): def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type - return value_type.exposed_type().value + return str(value_type.exposed_type()) class FullContentDict(TypedDict): @@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict result: FullContentDict = { "size_bytes": variable_file.size, - "value_type": variable_file.value_type.exposed_type().value, + "value_type": str(variable_file.value_type.exposed_type()), "length": variable_file.length, "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True), } @@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.exposed_type().value, + "value_type": str(v.value_type.exposed_type()), "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c4353ca7b8..ca4b18cb5e 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel): def normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 790602ef5d..c22102c2ba 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -42,7 +42,7 @@ from graphon.model_runtime.entities import ( ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index dbd7527fc6..5df3df2b3e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 09ddce327e..cae0eee0df 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -18,7 +18,7 @@ from core.moderation.base import ModerationError from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index dfe6133cb6..e2e07ebaff 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 68e5e5f0c8..3a6f9d575a 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.protocols import WorkflowFileRuntimeProtocol from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 87f005a250..d521304615 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index dc831e5cac..f0dcb13b62 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -352,11 +352,11 @@ class DatasourceManager: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 6bbf163c9d..38b87e2cd1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import ( FormType, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now @@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel): ) for key, value in validated_credentials.items(): - if key in provider_credential_secret_variables: + if key in provider_credential_secret_variables and isinstance(value, str): validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) return validated_credentials @@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel): ) for key, value in validated_credentials.items(): - if key in provider_credential_secret_variables: + if key in provider_credential_secret_variables and isinstance(value, str): validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) return validated_credentials diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce380..38864a1830 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index dc37a36943..f169f247cf 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_ from extensions.ext_hosting_provider import hosting_configuration from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e38592bb7b..91e92712b7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client from core.tools.errors import ToolSSRFError +from graphon.http.response import HttpResponse logger = logging.getLogger(__name__) @@ -267,4 +268,47 @@ class SSRFProxy: return patch(url=url, max_retries=max_retries, **kwargs) +def _to_graphon_http_response(response: httpx.Response) -> HttpResponse: + """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper.""" + return HttpResponse( + status_code=response.status_code, + headers=dict(response.headers), + content=response.content, + url=str(response.url) if response.url else None, + reason_phrase=response.reason_phrase, + fallback_text=response.text, + ) + + +class GraphonSSRFProxy: + """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + ssrf_proxy = SSRFProxy() +graphon_ssrf_proxy = GraphonSSRFProxy() diff --git a/api/core/model_manager.py b/api/core/model_manager.py index d8d8dfedd8..86d0e3baaa 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config from core.entities import PluginCredentialType @@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -168,7 +170,7 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -193,7 +195,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -213,7 +215,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, texts=texts, @@ -235,7 +237,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, @@ -252,7 +254,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, texts=texts, @@ -277,7 +279,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, query=query, @@ -305,7 +307,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, + self.model_type_instance.invoke_multimodal_rerank, model=self.model_name, credentials=self.credentials, query=query, @@ -324,7 +326,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, text=text, @@ -340,7 +342,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, file=file, @@ -357,14 +359,14 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, content_text=content_text, voice=voice, ) - def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index e3fba4ef3a..4e66d58b5e 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime): if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") file_name = ( - provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us ) elif icon_type.lower() == "icon_small_dark": if not provider_schema.icon_small_dark: raise ValueError(f"Provider {provider} does not have small dark icon.") file_name = ( - provider_schema.icon_small_dark.zh_Hans + provider_schema.icon_small_dark.zh_hans if lang.lower() == "zh_hans" - else provider_schema.icon_small_dark.en_US + else provider_schema.icon_small_dark.en_us ) else: raise ValueError(f"Unsupported icon type: {icon_type}.") diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a..7c6280fe93 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 4926f44f16..a9995778f7 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 052fca930d..0330a43b28 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -3,6 +3,7 @@ Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). """ +import inspect import logging import mimetypes import os @@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor): file_path: Path to the file to load. """ + _closed: bool + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" + self._closed = False self.file_path = file_path self.tenant_id = tenant_id self.user_id = user_id @@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") + def close(self) -> None: + """Best-effort cleanup for downloaded temporary files.""" + if getattr(self, "_closed", False): + return + + self._closed = True + temp_file = getattr(self, "temp_file", None) + if temp_file is None: + return + + try: + close_result = temp_file.close() + if inspect.isawaitable(close_result): + close_awaitable = getattr(close_result, "close", None) + if callable(close_awaitable): + close_awaitable() + except Exception: + logger.debug("Failed to cleanup downloaded word temp file", exc_info=True) + def __del__(self): - if hasattr(self, "temp_file"): - self.temp_file.close() + self.close() def extract(self) -> list[Document]: """Load given path as single page.""" diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index f8242efe31..7ffa9afafd 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): try: # Create File object directly (similar to DatasetRetrieval) file_obj = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 1453fe020b..5631b3a921 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile @@ -517,11 +517,11 @@ class DatasetRetrieval: if attachments_with_bindings: for _, upload_file in attachments_with_bindings: attachment_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2581c354dd..66b375dad1 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -9,7 +9,7 @@ from typing import Any, Literal from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer +from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f..740d727e26 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -8,7 +8,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b3424cd9a5..c87e8a3ae0 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -28,7 +28,7 @@ class ToolFileManager: def _build_graph_file_reference(tool_file: ToolFile) -> File: extension = guess_extension(tool_file.mimetype) or ".bin" return File( - type=get_file_type_by_mime_type(tool_file.mimetype), + file_type=get_file_type_by_mime_type(tool_file.mimetype), transfer_method=FileTransferMethod.TOOL_FILE, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f4588904d3..87cf6d7085 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1082,7 +1082,12 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.value + if not isinstance(variable_selector, list) or not all( + isinstance(selector_part, str) for selector_part in variable_selector + ): + raise ToolParameterError("Variable tool input must be a variable selector") + variable = variable_pool.get(variable_selector) if variable is None: raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39..a3623d4ecd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ToolModelInvoke diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 52ab605963..cd8c6352b5 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -357,7 +357,10 @@ class WorkflowTool(Tool): def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) - transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) + transfer_method_value = file_dict.get("transfer_method") + if not isinstance(transfer_method_value, str): + raise ValueError("Workflow file mapping is missing a valid transfer_method") + transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.TOOL_FILE: file_dict["tool_file_id"] = file_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py similarity index 74% rename from api/core/workflow/human_input_compat.py rename to api/core/workflow/human_input_adapter.py index 75a0a0c202..4b765e6aea 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_adapter.py @@ -1,8 +1,8 @@ -"""Workflow-layer adapters for legacy human-input payload keys. +"""Workflow-to-Graphon adapters for persisted node payloads. -Stored workflow graphs and editor payloads may still use Dify-specific human -input recipient keys. Normalize them here before handing configs to -`graphon` so graph-owned models only see graph-neutral field names. +Stored workflow graphs and editor payloads still contain a small set of +Dify-owned field spellings and value shapes. Adapt them here before handing the +payload to Graphon so Graphon-owned models only see current contracts. """ from __future__ import annotations @@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None: return None -def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") @@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: - normalized = normalize_human_input_node_data_for_graph(node_data) + normalized = adapt_human_input_node_data_for_graph(node_data) raw_delivery_methods = normalized.get("delivery_methods") if not isinstance(raw_delivery_methods, list): return [] @@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b return False -def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") - if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: - return normalized - return normalize_human_input_node_data_for_graph(normalized) + node_type = normalized.get("type") + if node_type == BuiltinNodeTypes.HUMAN_INPUT: + return adapt_human_input_node_data_for_graph(normalized) + if node_type == BuiltinNodeTypes.TOOL: + return _adapt_tool_node_data_for_graph(normalized) + return normalized -def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_config) if normalized is None: raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") @@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) if data_mapping is None: return normalized - normalized["data"] = normalize_node_data_for_graph(data_mapping) + normalized["data"] = adapt_node_data_for_graph(data_mapping) return normalized +def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(node_data) + + raw_tool_configurations = normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -291,9 +349,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 351da3444f..de4eae1b22 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType from graphon.file.file_manager import file_manager from graphon.graph.graph import NodeFactory from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.base.node import Node from graphon.nodes.code.code_node import WorkflowCodeExecutor from graphon.nodes.code.entities import CodeLanguage @@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. + # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. + resolved_node_data = self._validate_resolved_node_data(node_class, node_data) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. """ - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e632e56f0..b8725853c4 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlalchemy import select from sqlalchemy.orm import Session @@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import ( ) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.llm.runtime_protocols import ( PreparedLLMProtocol, @@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + def invoke_llm_with_structured_output( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0..68a24e86b1 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node @@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e4f6b3b470..f3006c4242 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import SystemVariableKey, get_system_segment -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d5cab05dbe..9c1b7ab2c4 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node @@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 47ad14b499..25f73e446d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -50,6 +49,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index ce1fa441c2..1d2ad4d445 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -148,11 +148,11 @@ def _build_from_local_file( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=row.name, extension="." + row.extension, mime_type=row.mime_type, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=row.source_url, reference=build_file_reference(record_id=str(row.id)), @@ -196,11 +196,11 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), reference=build_file_reference(record_id=str(upload_file.id)), @@ -222,9 +222,9 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=filename, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -263,9 +263,9 @@ def _build_from_tool_file( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=tool_file.name, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), @@ -306,9 +306,9 @@ def _build_from_datasource_file( ) return File( - id=mapping.get("datasource_file_id"), + file_id=mapping.get("datasource_file_id"), filename=datasource_file.name, - type=file_type, + file_type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, remote_url=datasource_file.source_url, reference=build_file_reference(record_id=str(datasource_file.id)), diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4..d518114777 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index cf4a71d545..e4219ba1ee 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f9b5e98936..6e947858ba 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.exposed_type().value, + "value_type": str(value.value_type.exposed_type()), "description": value.description, } if isinstance(value, dict): diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 9b53918f24..934aacb45b 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -6,8 +6,8 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.db.session_factory import session_factory from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) @@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def sync_data_source(self, binding_id: str) -> None: # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ) - if data_source_binding: - # get all authorized pages - pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) - new_source_info = self._build_source_info( - workspace_name=source_info["workspace_name"], - workspace_icon=source_info["workspace_icon"], - workspace_id=source_info["workspace_id"], - pages=pages, - ) - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - raise ValueError("Data source binding not found") + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: pages: list[NotionPageSummary] = [] diff --git a/api/models/comment.py b/api/models/comment.py index 7018c7e1f2..a30a52ec86 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional +import sqlalchemy as sa from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -36,24 +37,27 @@ class WorkflowComment(Base): __tablename__ = "workflow_comments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), Index("workflow_comments_app_idx", "tenant_id", "app_id"), Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - position_x: Mapped[float] = mapped_column(db.Float) - position_y: Mapped[float] = mapped_column(db.Float) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime) + resolved: Mapped[bool] = mapped_column( + sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) resolved_by: Mapped[str | None] = mapped_column(StringUUID) # Relationships @@ -143,23 +147,22 @@ class WorkflowCommentReply(Base): __tablename__ = "workflow_comment_replies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), Index("comment_replies_comment_idx", "comment_id"), Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) - comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False - ) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + comment: Mapped["WorkflowComment"] = relationship( + "WorkflowComment", back_populates="replies") @property def created_by_account(self): @@ -187,24 +190,23 @@ class WorkflowCommentMention(Base): __tablename__ = "workflow_comment_mentions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), Index("comment_mentions_comment_idx", "comment_id"), Index("comment_mentions_reply_idx", "reply_id"), Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False - ) - reply_id: Mapped[str | None] = mapped_column( - StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") - reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + comment: Mapped["WorkflowComment"] = relationship( + "WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"] + ] = relationship("WorkflowCommentReply") @property def mentioned_user_account(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index b4c7a634b6..7447d3efcb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index a2dc8f6157..77dcbd13d4 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -5,7 +5,8 @@ from functools import lru_cache from typing import Any from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. + """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -76,12 +195,7 @@ def build_file_from_stored_mapping( pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index 467d60f6ac..d25fa488e8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, @@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -292,7 +295,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -1690,7 +1693,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1700,7 +1703,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index cfcf6b307e..a8c480e4a5 100644 --- a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -1,7 +1,6 @@ -""" -Tencent APM tracing implementation with separated concerns -""" +"""Tencent APM tracing with idempotent client cleanup.""" +import inspect import logging from sqlalchemy import select @@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance): """ Tencent APM trace implementation with single responsibility principle. Acts as a coordinator that delegates specific tasks to specialized classes. + + The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen + explicitly in tests or implicitly during garbage collection, so shutdown + must be safe to call multiple times. """ + trace_client: TencentTraceClient + _closed: bool + def __init__(self, tencent_config: TencentConfig): super().__init__(tencent_config) + self._closed = False self.trace_client = TencentTraceClient( service_name=tencent_config.service_name, endpoint=tencent_config.endpoint, @@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance): except Exception: logger.debug("[Tencent APM] Failed to record message trace duration") - def __del__(self): - """Ensure proper cleanup on garbage collection.""" + def close(self) -> None: + """Synchronously and idempotently shutdown the underlying trace client.""" + if getattr(self, "_closed", False): + return + + self._closed = True + trace_client = getattr(self, "trace_client", None) + if trace_client is None: + return + try: - if hasattr(self, "trace_client"): - self.trace_client.shutdown() + shutdown_result = trace_client.shutdown() + if inspect.isawaitable(shutdown_result): + close_awaitable = getattr(shutdown_result, "close", None) + if callable(close_awaitable): + close_awaitable() except Exception: logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") + + def __del__(self): + """Ensure best-effort cleanup on garbage collection without retrying shutdown.""" + self.close() diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index a91a0aa558..54524b09ca 100644 --- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -1,5 +1,7 @@ +import gc import logging -from unittest.mock import MagicMock, patch +import warnings +from unittest.mock import AsyncMock, MagicMock, patch import pytest from dify_trace_tencent.config import TencentConfig @@ -632,13 +634,38 @@ class TestTencentDataTrace: with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) - def test_del(self, tencent_data_trace): + def test_close(self, tencent_data_trace): client = tencent_data_trace.trace_client - tencent_data_trace.__del__() + tencent_data_trace.close() client.shutdown.assert_called_once() - def test_del_exception(self, tencent_data_trace): + def test_close_is_idempotent(self, tencent_data_trace): + client = tencent_data_trace.trace_client + + tencent_data_trace.close() + tencent_data_trace.close() + + client.shutdown.assert_called_once() + + def test_close_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.__del__() + tencent_data_trace.close() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + + def test_close_handles_async_shutdown_mock(self, tencent_data_trace): + shutdown = AsyncMock() + tencent_data_trace.trace_client.shutdown = shutdown + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + tencent_data_trace.close() + gc.collect() + + shutdown.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) + and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] diff --git a/api/pyproject.toml b/api/pyproject.toml index a21d98f7bf..c349979a7c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ # Emerging: newer and fast-moving, use compatible pins "fastopenapi[flask]~=0.7.0", - "graphon~=0.1.2", + "graphon~=0.2.2", "httpx-sse~=0.4.0", "json-repair~=0.59.2", ] @@ -173,7 +173,7 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.60.0", + "pyrefly>=0.61.1", "xinference-client>=2.4.0", ] diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c..2cc0192a4a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 78806927bc..97aaea3395 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -17,6 +17,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): diff --git a/api/services/app_service.py b/api/services/app_service.py index afd98e2975..038c59633a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e6f5f80a6d..894cb05687 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -30,7 +30,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.file import helpers as file_helpers from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/services/feature_service.py b/api/services/feature_service.py index e4eb9e7582..a6b02630bf 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -225,6 +227,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 68ef67dec1..8b4983e5f7 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 968600d1bc..9db6682e10 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -476,7 +476,7 @@ class RagPipelineService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index c96050ce13..1529c2b98f 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator): return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 8afb565955..e7325f9c6a 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -1075,7 +1075,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index a68f69a8d3..95c931a46e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, + adapt_human_input_node_data_for_graph, parse_human_input_delivery_methods, ) from core.workflow.node_factory import ( @@ -112,7 +112,8 @@ class WorkflowService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize WorkflowService with repository dependencies.""" if session_maker is None: - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + session_maker = sessionmaker( + bind=db.engine, expire_on_commit=False) self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -216,7 +217,8 @@ class WorkflowService: if not app_ids: return set() - stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + stmt = select(App.id).where(App.id.in_( + app_ids), App.tenant_id == tenant_id) return {str(app_id) for app_id in db.session.scalars(stmt).all()} def get_all_published_workflow( @@ -300,7 +302,8 @@ class WorkflowService: ) ) - stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit) + stmt = stmt.order_by(Workflow.created_at.desc()).limit( + limit + 1).offset((page - 1) * limit) workflows = session.scalars(stmt).all() @@ -332,7 +335,8 @@ class WorkflowService: raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure(app_model=app_model, features=features) + self.validate_features_structure( + app_model=app_model, features=features) # validate graph structure self.validate_graph_structure(graph=graph) @@ -364,7 +368,8 @@ class WorkflowService: db.session.commit() # trigger app workflow events - app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow) + app_draft_workflow_was_synced.send( + app_model, synced_draft_workflow=workflow) # return draft workflow return workflow @@ -432,7 +437,8 @@ class WorkflowService: raise ValueError("No draft workflow found.") # validate features structure - self.validate_features_structure(app_model=app_model, features=features) + self.validate_features_structure( + app_model=app_model, features=features) workflow.features = json.dumps(features) workflow.updated_by = account.id @@ -453,11 +459,13 @@ class WorkflowService: Secret environment variables are copied server-side from the selected published workflow so the normal draft sync flow stays stateless. """ - source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + source_workflow = self.get_published_workflow_by_id( + app_model=app_model, workflow_id=workflow_id) if not source_workflow: raise WorkflowNotFoundError("Workflow not found.") - self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_features_structure( + app_model=app_model, features=source_workflow.normalized_features_dict) self.validate_graph_structure(graph=source_workflow.graph_dict) draft_workflow = self.get_draft_workflow(app_model=app_model) @@ -474,7 +482,8 @@ class WorkflowService: db.session.add(draft_workflow) db.session.commit() - app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + app_draft_workflow_was_synced.send( + app_model, synced_draft_workflow=draft_workflow) return draft_workflow @@ -518,7 +527,8 @@ class WorkflowService: and is_trigger_node_type(node_type_str) ) if trigger_node_count > 2: - raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2) + raise TriggerNodeLimitExceededError( + count=trigger_node_count, limit=2) # create new workflow workflow = Workflow.new( @@ -540,7 +550,8 @@ class WorkflowService: session.add(workflow) # trigger app workflow events - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) # return new workflow return workflow @@ -597,7 +608,8 @@ class WorkflowService: session.add(workflow) # trigger app workflow events - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) return workflow @@ -615,7 +627,8 @@ class WorkflowService: This endpoint only supports conversion between standard workflow and evaluation workflow. """ if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}: - raise ValueError("target_type must be either 'workflow' or 'evaluation'") + raise ValueError( + "target_type must be either 'workflow' or 'evaluation'") if not app_model.workflow_id: raise WorkflowNotFoundError("Published workflow not found") @@ -630,7 +643,8 @@ class WorkflowService: raise WorkflowNotFoundError("Published workflow not found") if workflow.version == Workflow.VERSION_DRAFT: - raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.") + raise IsDraftWorkflowError( + "Current effective workflow cannot be a draft version.") if workflow.type == target_type: return workflow @@ -642,7 +656,8 @@ class WorkflowService: workflow.updated_by = account.id workflow.updated_at = naive_utc_now() - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) return workflow @@ -660,7 +675,8 @@ class WorkflowService: if not disallowed_nodes: return - formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes) + formatted_nodes = ", ".join( + f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes) raise ValueError( "Evaluation workflow cannot contain trigger or human-input nodes. " f"Found disallowed nodes: {formatted_nodes}" @@ -698,12 +714,14 @@ class WorkflowService: ) else: # Check default workspace credential for this provider - self._check_default_tool_credential(workflow.tenant_id, provider) + self._check_default_tool_credential( + workflow.tenant_id, provider) elif node_type == "agent": agent_params = node_data.get("agent_parameters", {}) - model_config = agent_params.get("model", {}).get("value", {}) + model_config = agent_params.get( + "model", {}).get("value", {}) if model_config.get("provider") and model_config.get("model"): self._validate_llm_model_config( workflow.tenant_id, model_config["provider"], model_config["model"] @@ -711,7 +729,8 @@ class WorkflowService: # Validate load balancing credentials for agent model if load balancing is enabled agent_model_node_data = {"model": model_config} - self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) + self._validate_load_balancing_credentials( + workflow, agent_model_node_data, node_id) # Validate agent tools tools = agent_params.get("tools", {}).get("value", []) @@ -723,9 +742,11 @@ class WorkflowService: if credential_id: from core.helper.credential_utils import check_credential_policy_compliance - check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) + check_credential_policy_compliance( + credential_id, provider, PluginCredentialType.TOOL) else: - self._check_default_tool_credential(workflow.tenant_id, provider) + self._check_default_tool_credential( + workflow.tenant_id, provider) elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: model_config = node_data.get("model", {}) @@ -734,11 +755,14 @@ class WorkflowService: if provider and model_name: # Validate that the provider+model combination can fetch valid credentials - self._validate_llm_model_config(workflow.tenant_id, provider, model_name) + self._validate_llm_model_config( + workflow.tenant_id, provider, model_name) # Validate load balancing credentials if load balancing is enabled - self._validate_load_balancing_credentials(workflow, node_data, node_id) + self._validate_load_balancing_credentials( + workflow, node_data, node_id) else: - raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") + raise ValueError( + f"Node {node_id} ({node_type}): Missing provider or model configuration") except Exception as e: if isinstance(e, ValueError): @@ -780,8 +804,10 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_configurations = assembly.provider_manager.get_configurations(tenant_id) - models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) + provider_configurations = assembly.provider_manager.get_configurations( + tenant_id) + models = provider_configurations.get_models( + provider=provider, model_type=ModelType.LLM) target_model = None for model in models: @@ -792,7 +818,8 @@ class WorkflowService: if target_model: target_model.raise_for_status() else: - raise ValueError(f"Model {model_name} not found for provider {provider}") + raise ValueError( + f"Model {model_name} not found for provider {provider}") except Exception as e: raise ValueError( @@ -840,7 +867,8 @@ class WorkflowService: ) except Exception as e: - raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") + raise ValueError( + f"Failed to validate default credential for tool provider {provider}: {str(e)}") def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ @@ -862,7 +890,8 @@ class WorkflowService: # Check if this model has load balancing enabled if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): # Get all load balancing configurations for this model - load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) + load_balancing_configs = self._get_load_balancing_configs( + workflow.tenant_id, provider, model_name) # Validate each load balancing configuration try: for config in load_balancing_configs: @@ -873,7 +902,8 @@ class WorkflowService: config["credential_id"], provider, PluginCredentialType.MODEL ) except Exception as e: - raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") + raise ValueError( + f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: """ @@ -888,8 +918,10 @@ class WorkflowService: from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager( + tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations( + tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -930,7 +962,8 @@ class WorkflowService: _, custom_configs = model_load_balancing_service.get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" ) - all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs) + all_configs = cast(list[dict[str, Any]], configs) + \ + cast(list[dict[str, Any]], custom_configs) return [config for config in all_configs if config.get("credential_id")] @@ -975,7 +1008,7 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -994,7 +1027,8 @@ class WorkflowService: ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - default_config = node_class.get_default_config(filters=resolved_filters or None) + default_config = node_class.get_default_config( + filters=resolved_filters or None) if not default_config: return {} @@ -1017,7 +1051,8 @@ class WorkflowService: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id) + draft_var_srv.prefill_conversation_variable_default_values( + draft_workflow, user_id=account.id) node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) @@ -1031,7 +1066,8 @@ class WorkflowService: workflow=draft_workflow, ) if node_type == BuiltinNodeTypes.START: - start_data = StartNodeData.model_validate(node_data, from_attributes=True) + start_data = StartNodeData.model_validate( + node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs ) @@ -1066,7 +1102,8 @@ class WorkflowService: user_id=account.id, ) - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id( + node_config) if enclosing_node_type_and_id: _, enclosing_node_id = enclosing_node_type_and_id else: @@ -1101,12 +1138,15 @@ class WorkflowService: ) repository.save(node_execution) - workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id( + node_execution.id) if workflow_node_execution is None: - raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") + raise ValueError( + f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with sessionmaker(db.engine).begin() as session: - outputs = workflow_node_execution.load_full_outputs(session, storage) + outputs = workflow_node_execution.load_full_outputs( + session, storage) with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( @@ -1118,7 +1158,8 @@ class WorkflowService: node_execution_id=node_execution.id, user=account, ) - draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) + draft_var_saver.save( + process_data=node_execution.process_data, outputs=outputs) enqueue_draft_node_execution_trace( execution=workflow_node_execution, @@ -1245,7 +1286,8 @@ class WorkflowService: rendered_content, outputs, node_data.outputs_field_names() ) - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id( + node_config) enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( @@ -1280,7 +1322,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1332,7 +1374,8 @@ class WorkflowService: try: test_service.send_test(context=context, method=delivery_method) except DeliveryTestUnsupportedError as exc: - raise ValueError("Delivery method does not support test send.") from exc + raise ValueError( + "Delivery method does not support test send.") from exc except DeliveryTestError as exc: raise ValueError(str(exc)) from exc @@ -1357,7 +1400,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) + repo = HumanInputFormRepositoryImpl( + tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( workflow_execution_id=None, node_id=node_id, @@ -1377,7 +1421,8 @@ class WorkflowService: with Session(bind=db.engine) as session: recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) + select(HumanInputFormRecipient).where( + HumanInputFormRecipient.form_id == form_id) ).all() recipients_data: list[DeliveryTestEmailRecipient] = [] for recipient in recipients: @@ -1388,11 +1433,13 @@ class WorkflowService: try: payload = json.loads(recipient.recipient_payload) except (json.JSONDecodeError, ValueError): - logger.exception("Failed to parse human input recipient payload for delivery test.") + logger.exception( + "Failed to parse human input recipient payload for delivery test.") continue email = payload.get("email") if email: - recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token)) + recipients_data.append(DeliveryTestEmailRecipient( + email=email, form_token=recipient.access_token)) return recipients_data def _build_human_input_node( @@ -1421,9 +1468,11 @@ class WorkflowService: variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data( + adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), @@ -1441,7 +1490,8 @@ class WorkflowService: ) -> VariablePool: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) + draft_var_srv.prefill_conversation_variable_default_values( + workflow, user_id=user_id) variable_pool = VariablePool() add_variables_to_pool( @@ -1519,7 +1569,8 @@ class WorkflowService: Returns: WorkflowNodeExecution: The execution result """ - node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn) + node, node_run_result, run_succeeded, error = self._execute_node_safely( + invoke_node_fn) # Create base node execution node_execution = WorkflowNodeExecution( @@ -1535,7 +1586,8 @@ class WorkflowService: ) # Populate execution result data - self._populate_execution_result(node_execution, node_run_result, run_succeeded, error) + self._populate_execution_result( + node_execution, node_run_result, run_succeeded, error) return node_execution @@ -1564,7 +1616,8 @@ class WorkflowService: # Apply error strategy if node failed if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy: - node_run_result = self._apply_error_strategy(node, node_run_result) + node_run_result = self._apply_error_strategy( + node, node_run_result) run_succeeded = node_run_result.status in ( WorkflowNodeExecutionStatus.SUCCEEDED, @@ -1595,7 +1648,8 @@ class WorkflowService: status=WorkflowNodeExecutionStatus.EXCEPTION, error=node_run_result.error, inputs=node_run_result.inputs, - metadata={WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy}, + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy}, outputs=error_outputs, ) @@ -1609,10 +1663,12 @@ class WorkflowService: """Populate node execution with result data.""" if run_succeeded and node_run_result: node_execution.inputs = ( - WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + WorkflowEntry.handle_special_values( + node_run_result.inputs) if node_run_result.inputs else None ) node_execution.process_data = ( - WorkflowEntry.handle_special_values(node_run_result.process_data) + WorkflowEntry.handle_special_values( + node_run_result.process_data) if node_run_result.process_data else None ) @@ -1641,7 +1697,8 @@ class WorkflowService: workflow_converter = WorkflowConverter() if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}: - raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") + raise ValueError( + f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app: App = workflow_converter.convert_to_workflow( @@ -1678,7 +1735,8 @@ class WorkflowService: # start node and trigger node cannot coexist if BuiltinNodeTypes.START in node_types: if any(is_trigger_node_type(nt) for nt in node_types): - raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") + raise ValueError( + "Start node and trigger nodes cannot coexist in the same workflow") for node in node_configs: node_data = node.get("data", {}) @@ -1713,7 +1771,8 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate( + adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1730,7 +1789,8 @@ class WorkflowService: :param data: Dictionary containing fields to update :return: Updated workflow or None if not found """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + stmt = select(Workflow).where(Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) if not workflow: @@ -1759,7 +1819,8 @@ class WorkflowService: :raises: WorkflowInUseError if workflow is in use :raises: DraftWorkflowDeletionError if workflow is a draft version """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + stmt = select(Workflow).where(Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) if not workflow: @@ -1767,14 +1828,16 @@ class WorkflowService: # Check if workflow is a draft version if workflow.version == Workflow.VERSION_DRAFT: - raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") + raise DraftWorkflowDeletionError( + "Cannot delete draft workflow versions") # Check if this workflow is currently referenced by an app app_stmt = select(App).where(App.workflow_id == workflow_id) app = session.scalar(app_stmt) if app: # Cannot delete a workflow that's currently in use by an app - raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'") + raise WorkflowInUseError( + f"Cannot delete workflow that is currently in use by app '{app.id}'") # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version @@ -1788,7 +1851,8 @@ class WorkflowService: if tool_provider: # Cannot delete a workflow that's published as a tool - raise WorkflowInUseError("Cannot delete workflow that is published as a tool") + raise WorkflowInUseError( + "Cannot delete workflow that is published as a tool") session.delete(workflow) return True @@ -1837,11 +1901,13 @@ def _setup_variable_pool( build_bootstrap_variables( system_variables=system_variable, environment_variables=workflow.environment_variables, - conversation_variables=cast(list[Variable], conversation_variables), + conversation_variables=cast( + list[Variable], conversation_variables), ), ) if is_start_node_type(node_type): - add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + add_node_inputs_to_pool( + variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1857,7 +1923,8 @@ def _rebuild_file_for_user_inputs_in_start_node( if variable.variable not in user_inputs: continue value = user_inputs[variable.variable] - file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + file = _rebuild_single_file( + tenant_id=tenant_id, value=value, variable_entity_type=variable.type) inputs_copy[variable.variable] = file return inputs_copy @@ -1865,15 +1932,18 @@ def _rebuild_file_for_user_inputs_in_start_node( def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): - raise ValueError(f"expected dict for file object, got {type(value)}") + raise ValueError( + f"expected dict for file object, got {type(value)}") return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): - raise ValueError(f"expected list for file list object, got {type(value)}") + raise ValueError( + f"expected list for file list object, got {type(value)}") if len(value) == 0: return [] if not isinstance(value[0], dict): - raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + raise ValueError( + f"expected dict for first element in the file list, got {type(value)}") return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e..2a60be7762 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail from graphon.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index b5318aaa2b..2392084c36 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamCompletedEvent @@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3476c292b..aaa6092993 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index aa6cf1e021..b9f7b9575b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.file.file_manager import file_manager from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index fa5d63cfbf..3eead70163 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory @@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 52886855b8..f2eabb86c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -11,6 +11,7 @@ from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance @@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9e3e1a47e3..e2e0723fb8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError @@ -86,8 +87,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f9ec51ee10..a8e9422c1e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import StreamCompletedEvent from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool.tool_node import ToolNode from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -60,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 14d5740072..6524d6ce61 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index da4f8847d6..5aed230cd4 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -101,8 +101,8 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, + node_id="start", + config=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -116,8 +116,8 @@ def _build_graph( ], ) human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, + node_id="human", + config=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -130,8 +130,8 @@ def _build_graph( desc=None, ) end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, + node_id="end", + config=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 2e207ddc67..35e41035df 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase): file_related_id = related_id return File( - id=str(uuid4()), # Generate new UUID for File.id + file_id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, remote_url=remote_url, diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index aaf9a85d60..54b7afc018 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 18c5320d0a..80f9083e81 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 21a54e909e..ed75363f3b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -8,7 +8,7 @@ import pytest from sqlalchemy.engine import Engine from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 328bdbf055..95a867dbb5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 6ff3b19362..e91c0a0597 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index b19a1740eb..22b80b748e 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 14c35a9ed5..4fb8ecf784 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import ( ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -284,6 +286,32 @@ class TestConversationVariableResponseModels: assert response.created_at == int(created_at.timestamp()) assert response.updated_at == int(created_at.timestamp()) + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + def test_variable_pagination_response(self): response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( { diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 3ab63aed25..dd6cd0e919 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a04a7b7576..6104b8d6ca 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index cddd03f4b0..701863b927 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index c4bfb23272..30a068f4c5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 81315d2508..deeac49bbc 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index a0b2820157..aeca2e3afd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e224..8cb0938575 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e6597..d9fed9ae2a 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb5006..c4fd970562 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa130518..88bf555594 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index d49b6e4b71..00a4207786 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d392127..e536c0831f 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d7..28966242d8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695..5308c8e7b3 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a..0220fb6d4a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,12 +1,14 @@ """Primarily used for testing merged cell scenarios""" +import gc import io import os import tempfile +import warnings from collections import UserDict from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from docx import Document @@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +def test_close_handles_async_close_mock(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + extractor.temp_file.close = AsyncMock() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + extractor.close() + gc.collect() + + extractor.temp_file.close.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 8be1ac318c..18ae9fafc8 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,7 +14,7 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df1..4248782d93 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -21,7 +21,7 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b..eab0176f41 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -6,9 +6,9 @@ from models.workflow import Workflow def test_file_to_dict(): file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="storage_key", diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 72052c8c05..9e07ea1b6d 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,8 +1,9 @@ import dataclasses +from typing import Annotated import orjson import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Discriminator, Tag from core.helper import encrypter from core.workflow.system_variables import build_bootstrap_variables, build_system_variables @@ -12,17 +13,18 @@ from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, - SegmentUnion, StringSegment, get_segment_discriminator, ) @@ -47,6 +49,26 @@ from graphon.variables.variables import ( StringVariable, Variable, ) +from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup + +type SegmentUnion = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] + ), + Discriminator(get_segment_discriminator), +] def _build_variable_pool( @@ -123,7 +145,7 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, @@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad: assert restored == model def test_all_segments_serialization(self): - """Test serialization/deserialization of all segment types""" + """Test file-aware segment serialization through Dify's model boundary.""" # Create one instance of each segment type test_file = create_test_file() @@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Segments(segments=all_segments) json_str = model.model_dump_json() - loaded = _Segments.model_validate_json(json_str) + loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all segments are preserved assert len(loaded.segments) == len(all_segments) @@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad: assert loaded_segment.value == original.value def test_all_variables_serialization(self): - """Test serialization/deserialization of all variable types""" + """Test file-aware variable serialization through Dify's model boundary.""" # Create one instance of each variable type test_file = create_test_file() @@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Variables(variables=all_variables) json_str = model.model_dump_json() - loaded = _Variables.model_validate_json(json_str) + loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all variables are preserved assert len(loaded.variables) == len(all_variables) diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 94e788edb2..317fe99d37 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -35,7 +35,7 @@ def create_test_file( """Factory function to create File objects for testing.""" return File( tenant_id="test-tenant", - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b..9f3e3b00b9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -1,12 +1,13 @@ -""" -Mock node factory for testing workflows with third-party service dependencies. +"""Mock node factory for third-party-service workflow tests. -This module provides a MockNodeFactory that automatically detects and mocks nodes -requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). +The factory follows the same config adaptation path as production +`DifyNodeFactory.create_node()`, but swaps selected node classes for mock +implementations before instantiation. """ from typing import TYPE_CHECKING, Any +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType @@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory): :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) + node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = typed_node_config["id"] - # Create mock node instance mock_class = self._mock_node_types[node_type] + resolved_node_data = self._validate_resolved_node_data(mock_class, node_data) if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory): ) elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory): BuiltinNodeTypes.PARAMETER_EXTRACTOR, }: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory): ) else: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbf..f9819c47ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -55,13 +55,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -96,7 +97,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 55a329eba9..75bc6d05f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58..76b4cd1ef4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.answer.answer_node import AnswerNode +from graphon.nodes.answer.entities import AnswerNodeData from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -66,20 +67,15 @@ def test_execute_answer(): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - } - node = AnswerNode( - id=str(uuid.uuid4()), + node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, + config=AnswerNodeData( + title="123", + type="answer", + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + ), ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 9cceadde49..d7ef781732 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=gp, graph_runtime_state=gs, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index a3cadc0681..2e89a2da3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.file.file_manager import file_manager from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config(): assert default_config["retry_config"]["max_retries"] == 7 -def test_get_default_config_with_malformed_http_request_config_raises_value_error(): - with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): +def test_get_default_config_with_malformed_http_request_config_raises_type_error(): + with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"): HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}) @@ -114,8 +114,8 @@ def _build_http_node( start_at=time.perf_counter(), ) return HttpRequestNode( - id="http-node", - config=node_config, + node_id="http-node", + config=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index 1d6a4da7c4..07430498e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,4 +1,4 @@ -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index c0e21d0bf7..0659984c76 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -19,7 +19,7 @@ from core.repositories.human_input_repository import ( HumanInputFormRecipientEntity, HumanInputFormRepository, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, @@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +def _build_human_input_node( + *, + node_id: str, + node_data: HumanInputNodeData | Mapping[str, Any], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + runtime: DifyHumanInputNodeRuntime, +) -> HumanInputNode: + typed_node_data = ( + node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data) + ) + return HumanInputNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + runtime=runtime, + ) + + class TestDeliveryMethod: """Test DeliveryMethod entity.""" @@ -239,7 +259,7 @@ class TestUserAction: data[field_name] = value with pytest.raises(ValidationError) as exc_info: - UserAction(**data) + UserAction.model_validate(data) errors = exc_info.value.errors() assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) @@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent: form_repository = InMemoryHumanInputFormRepository() runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index bc98028d5b..4a9438b14f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -11,6 +11,7 @@ from graphon.graph_events import ( NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) +from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool @@ -25,6 +26,28 @@ class _FakeFormRepository: return self._form +def _create_human_input_node( + *, + config: dict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + repo: _FakeFormRepository, +) -> HumanInputNode: + node_data = ( + config["data"] + if isinstance(config["data"], HumanInputNodeData) + else HumanInputNodeData.model_validate(config["data"]) + ) + return HumanInputNode( + node_id=config["id"], + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + ) + + def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( @@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) @@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode: ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 82cc734274..8ffce39cd6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -5,6 +5,7 @@ import pytest from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams +from graphon.nodes.iteration.entities import IterationNodeData from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode from graphon.runtime import ( @@ -44,17 +45,14 @@ def _build_iteration_node( ) -> IterationNode: init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( - id="iteration-node", - config={ - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration-node", "output"], - "start_node_id": start_node_id, - }, - }, + node_id="iteration-node", + config=IterationNodeData( + type="iteration", + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id=start_node_id, + ), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index a6fca1bfb4..f254fc3d09 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -93,6 +93,25 @@ def sample_chunks(): } +def _build_node( + *, + node_id: str, + node_data: KnowledgeIndexNodeData | dict[str, object], + graph_init_params, + graph_runtime_state, +) -> KnowledgeIndexNode: + return KnowledgeIndexNode( + node_id=node_id, + config=( + node_data + if isinstance(node_data, KnowledgeIndexNodeData) + else KnowledgeIndexNodeData.model_validate(node_data) + ), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + class TestKnowledgeIndexNode: """ Test suite for KnowledgeIndexNode. @@ -115,9 +134,9 @@ class TestKnowledgeIndexNode: } # Act - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -143,9 +162,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -176,9 +195,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -212,9 +231,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -269,9 +288,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -332,9 +351,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -383,9 +402,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -440,9 +459,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -498,9 +517,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -536,9 +555,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -583,9 +602,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45e8ae7d20..e923ee761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import ( + KnowledgeRetrievalNode, + _normalize_metadata_filter_scalar, + _normalize_metadata_filter_sequence_item, +) from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus @@ -85,6 +89,12 @@ def sample_node_data(): ) +def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None: + assert _normalize_metadata_filter_scalar(3) == 3 + assert _normalize_metadata_filter_scalar(True) == "True" + assert _normalize_metadata_filter_sequence_item(4) == "4" + + class TestKnowledgeRetrievalNode: """ Test suite for KnowledgeRetrievalNode. @@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -470,8 +480,8 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -507,8 +517,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -562,8 +572,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -610,8 +620,8 @@ class TestFetchDatasetRetriever: mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -671,8 +681,8 @@ class TestFetchDatasetRetriever: node_id = str(uuid.uuid4()) config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index eca34f05be..388654f279 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -5,6 +6,7 @@ import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.entities import ListOperatorNodeData from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment @@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" + @staticmethod + def _build_node(*, config, graph_init_params, graph_runtime_state): + return ListOperatorNode( + node_id="test", + config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + @staticmethod + def _filter_by(comparison_operator: str, value: str) -> dict[str, object]: + return { + "enabled": True, + "conditions": [{"comparison_operator": comparison_operator, "value": value}], + } + @pytest.fixture def mock_graph_runtime_state(self): """Create mock GraphRuntimeState.""" mock_state = MagicMock(spec=GraphRuntimeState) mock_variable_pool = MagicMock() + mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value) mock_state.variable_pool = mock_variable_pool return mock_state @@ -45,9 +64,8 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable - return ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + return self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -64,9 +82,8 @@ class TestListOperatorNode: "limit": {"enabled": False}, } - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -109,9 +126,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=[]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -128,11 +144,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "contains", - "value": "app", - }, + "filter_by": self._filter_by("contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -140,9 +152,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -157,11 +168,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "not contains", - "value": "app", - }, + "filter_by": self._filter_by("not contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -169,9 +176,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -186,11 +192,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "5", - }, + "filter_by": self._filter_by(">", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -198,9 +200,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -226,9 +227,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -254,9 +254,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,9 +281,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -299,11 +297,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "3", - }, + "filter_by": self._filter_by(">", "3"), "order_by": { "enabled": True, "value": "desc", @@ -317,9 +311,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -341,9 +334,8 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -366,9 +358,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["first", "middle", "last"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,11 +375,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "start with", - "value": "app", - }, + "filter_by": self._filter_by("start with", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -396,9 +383,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -413,11 +399,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "end with", - "value": "le", - }, + "filter_by": self._filter_by("end with", "le"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -425,9 +407,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -442,11 +423,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "=", - "value": "5", - }, + "filter_by": self._filter_by("=", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -454,9 +431,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -471,11 +447,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "≠", - "value": "5", - }, + "filter_by": self._filter_by("≠", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -483,9 +455,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -511,9 +482,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 4186bbdc93..212ad07bd3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -71,8 +71,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -95,6 +95,8 @@ def variable_pool() -> VariablePool: def _fetch_prompt_messages_with_mocked_content(content): variable_pool = VariablePool.empty() model_instance = mock.MagicMock(spec=ModelInstance) + model_schema = mock.MagicMock() + model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text" prompt_template = [ LLMNodeChatModelMessage( text="You are a classifier.", @@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( "graphon.nodes.llm.llm_utils.fetch_model_schema", - return_value=mock.MagicMock(features=[]), + return_value=model_schema, ), mock.patch( "graphon.nodes.llm.llm_utils.handle_list_messages", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index b1f81b6c48..c707cf28cd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -140,8 +140,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -205,14 +205,10 @@ def llm_node( mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers(): def test_fetch_files_with_file_segment(): file = File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment(): def test_fetch_files_with_array_file_segment(): files = [ File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", storage_key="", ), File( - id="2", - type=FileType.IMAGE, + file_id="2", + file_type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", @@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/jpg", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: image_b64_data = base64.b64encode(image_raw_data).decode() mock_saved_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", extension=".png", @@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable file_saver=file_saver, file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, reasoning_format="separated", ) ) @@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=_build_prepared_llm_mock(), reasoning_format="separated", request_start_time=1.0, @@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors(): file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=model_instance, ) ) diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index bc44ececd8..892f6cc586 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params +def _build_template_transform_node( + *, + node_data, + graph_init_params, + graph_runtime_state, + node_id: str = "test_node", + **kwargs, +) -> TemplateTransformNode: + typed_node_data = ( + node_data + if isinstance(node_data, TemplateTransformNodeData) + else TemplateTransformNodeData.model_validate(node_data) + ) + return TemplateTransformNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + **kwargs, + ) + + class TestTemplateTransformNode: """Comprehensive test suite for TemplateTransformNode.""" @@ -59,9 +81,8 @@ class TestTemplateTransformNode: def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -75,9 +96,8 @@ class TestTemplateTransformNode: def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -88,9 +108,8 @@ class TestTemplateTransformNode: def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -108,9 +127,8 @@ class TestTemplateTransformNode: } mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -143,9 +161,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() with pytest.raises(ValueError, match="max_output_length must be a positive integer"): - TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -170,9 +187,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -198,9 +214,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Value: " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -218,9 +233,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -238,9 +252,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -260,9 +273,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "1234567890" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -302,9 +314,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -375,8 +386,8 @@ class TestTemplateTransformNode: ) assert mapping == { - "node_123.var1": ["sys", "input1"], - "node_123.empty_selector": [], + "node_123.var1": ("sys", "input1"), + "node_123.empty_selector": (), } def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): @@ -409,9 +420,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a static message." - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -448,9 +458,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Total: $31.5" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -477,9 +486,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -507,9 +515,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index 636237e56e..a846efbb43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -4,6 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, @@ -37,15 +38,13 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( - id="test_node", - config={ - "id": "test_node", - "data": { - "title": "Template Transform", - "variables": [], - "template": "hello", - }, - }, + node_id="test_node", + config=TemplateTransformNodeData( + title="Template Transform", + type="template-transform", + variables=[], + template="hello", + ), graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=MagicMock(), @@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri assert mapping == { "node_123.validated": ["sys", "input1"], - "node_123.raw": ["sys", "input2"], + "node_123.raw": ("sys", "input2"), } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0522dd9d14..364408ead6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool @@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - }, - } - ) +def _build_node_config() -> dict[str, object]: + return { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + ), + } + + +def _build_node_data() -> _SampleNodeData: + return _build_node_config()["data"] # type: ignore[return-value] def test_node_hydrates_data_during_initialization(): @@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization(): init_params, runtime_state = _build_context(graph_config) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum(): ) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): - node_data = _SampleNodeData.model_validate( - { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - } - ) + node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar") assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" @@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): def test_node_hydration_preserves_compatibility_extra_fields(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) - node_config = NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - "compat_flag": True, - }, - } - ) + node_config = { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + compat_flag=True, + ), + } node = _SampleNode( - id="node-1", - config=node_config, + node_id="node-1", + config=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bce..dd75b32593 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, + _extract_text_from_file, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from graphon.variables import ArrayFileSegment +from graphon.variables import ArrayFileSegment, FileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - node_config = {"id": "test_node_id", "data": node_data.model_dump()} http_client = Mock() node = DocumentExtractorNode( - id="test_node_id", - config=node_config, + node_id="test_node_id", + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"] - mock_excel_instance.parse.side_effect = [df, Exception("Parse error")] + mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_mixed_content" @@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"] - mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")] + mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_all_bad_sheets" @@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert mock_excel_instance.parse.call_count == 2 +@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook")) +def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file): + with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"): + _extract_text_from_excel(b"broken") + + @patch("pandas.ExcelFile") def test_extract_text_from_excel_numeric_type_column(mock_excel_file): """Test extracting text from Excel file with numeric column names.""" @@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): assert expected_manual == result +@pytest.mark.parametrize( + ("extension", "mime_type"), + [ + (".xlsx", "text/plain"), + (None, "application/vnd.ms-excel"), + ], +) +def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type): + file = Mock(spec=File) + file.extension = extension + file.mime_type = mime_type + + with ( + patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", + ), + patch( + "graphon.nodes.document_extractor.node._extract_text_from_excel", + return_value="excel text", + ) as mock_extract, + ): + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + assert result == "excel text" + mock_extract.assert_called_once_with(b"excel") + + +def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): + file = Mock(spec=File) + file.extension = None + file.mime_type = None + + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"unknown", + ): + with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"): + _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + +def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_list = Mock(spec=ArrayFileSegment) + file_list.value = [Mock(spec=File)] + mock_graph_runtime_state.variable_pool.get.return_value = file_list + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("bad file"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "bad file" + + +def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("single file failed"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "single file failed" + + +def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + return_value="single file text", + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {"text": "single file text"} + + def _make_docx_zip(use_backslash: bool) -> bytes: """Helper to build a minimal in-memory DOCX zip. diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e..aa9a1360b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params +def _build_if_else_node( + *, + node_data: IfElseNodeData | dict[str, object], + init_params, + graph_runtime_state, +) -> IfElseNode: + return IfElseNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + ) + + def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} @@ -61,9 +75,8 @@ def test_execute_if_else_result_true(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "and", @@ -104,13 +117,8 @@ def test_execute_if_else_result_true(): {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -155,9 +163,8 @@ def test_execute_if_else_result_false(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "or", @@ -174,13 +181,8 @@ def test_execute_if_else_result_false(): }, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -222,11 +224,6 @@ def test_array_file_contains_file_name(): ], ) - node_config = { - "id": "if-else", - "data": node_data.model_dump(), - } - # Create properly configured mock for graph_init_params graph_init_params = Mock() graph_init_params.workflow_id = "test_workflow" @@ -242,17 +239,12 @@ def test_array_file_contains_file_name(): } } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=graph_init_params, - graph_runtime_state=Mock(), - config=node_config, - ) + node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock()) node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", @@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition): "logical_operator": "and", "conditions": [condition.model_dump()], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() @@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions(): ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={ - "id": "if-else", - "data": node_data, - }, ) # Mock db.session.close() @@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure(): } ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7..465a4c0ff4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract from graphon.variables import ArrayFileSegment +def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: + return ListOperatorNode( + node_id="test_node_id", + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=MagicMock(), + ) + + @pytest.fixture def list_operator_node(): config = { @@ -35,10 +44,6 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData.model_validate(config) - node_config = { - "id": "test_node_id", - "data": node_data.model_dump(), - } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() graph_init_params.workflow_id = "test_workflow" @@ -54,12 +59,7 @@ def list_operator_node(): } } - node = ListOperatorNode( - id="test_node_id", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=MagicMock(), - ) + node = _build_list_operator_node(node_data, graph_init_params) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node @@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node): files = [ File( filename="image1.jpg", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", ), File( filename="document1.pdf", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", ), File( filename="image2.png", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", ), File( filename="audio1.mp3", - type=FileType.AUDIO, + file_type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", extension=".txt", @@ -156,7 +156,7 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, extension=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de..5655f80737 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables): inputs=user_inputs, ) - config = { - "id": "start", - "data": StartNodeData(title="Start", variables=variables).model_dump(), - } + node_data = StartNodeData(title="Start", variables=variables) graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, @@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables): ) return StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -109,7 +106,7 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"): + with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"): node._run() @@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot(): inputs={"profile": {"age": 20, "name": "Tom"}}, ) - config = { - "id": "start", - "data": StartNodeData( - title="Start", - variables=[ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ], - ).model_dump(), - } + node_data = StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c806181340..284af68319 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment @@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode: runtime = _StubToolRuntime() node = ToolNode( - id="node-instance", - config=config, + node_id="node-instance", + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode: return node -def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: +def _collect_events(generator: Generator) -> list[Any]: events: list[Any] = [] try: while True: events.append(next(generator)) - except StopIteration as stop: - return events, stop.value + except StopIteration: + return events def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: @@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li node_id=tool_node._node_id, tool_runtime=ToolRuntimeHandle(raw=object()), ) - return _collect_events(generator) + events = _collect_events(generator) + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert completed_events + return events, completed_events[-1].node_run_result.llm_usage def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", @@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284..e3b5e3b591 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": TRIGGER_PLUGIN_NODE_TYPE, - "title": "Trigger Event", - "plugin_id": "plugin-id", - "provider_id": "provider-id", - "event_name": "event-name", - "subscription_id": "subscription-id", - "plugin_unique_identifier": "plugin-unique-identifier", - "event_parameters": {}, - }, - } +def _build_node_data() -> TriggerEventNodeData: + return TriggerEventNodeData( + type=TRIGGER_PLUGIN_NODE_TYPE, + title="Trigger Event", + plugin_id="plugin-id", + provider_id="provider-id", + event_name="event-name", + subscription_id="subscription-id", + plugin_unique_identifier="plugin-unique-identifier", + event_parameters={}, ) def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 1bbc12b23f..07d03bec05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 427afa96ec..b839490d3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node @@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py new file mode 100644 index 0000000000..8b5fceeb37 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py @@ -0,0 +1,350 @@ +from types import SimpleNamespace + +import pytest +from pydantic import BaseModel + +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + adapt_human_input_node_data_for_graph, + adapt_node_config_for_graph, + adapt_node_data_for_graph, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert " Team") - html = EmailDeliveryConfig.render_markdown_body( - "**Hello** [mail](mailto:test@example.com)" - ) - - assert rendered == "Open https://example.com and use 42" - assert sanitized == "Hello alert(1) Team" - assert "Hello" in html - assert " + +` + +export const getChromePluginContent = (iframeUrl: string) => `ChatBot URL: ${iframeUrl}` + +export const compressAndEncodeBase64 = async (input: string) => { + const uint8Array = new TextEncoder().encode(input) + if (typeof CompressionStream === 'undefined') + return btoa(String.fromCharCode(...uint8Array)) + + const compressedStream = new Response( + new Blob([uint8Array]) + .stream() + .pipeThrough(new CompressionStream('gzip')), + ).arrayBuffer() + const compressedUint8Array = new Uint8Array(await compressedStream) + return btoa(String.fromCharCode(...compressedUint8Array)) +} + export const getAppCardDisplayState = ({ appInfo, cardType, diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index f02929c823..fc0132c45a 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -1,13 +1,14 @@ 'use client' +import type { WorkflowLaunchInputValue } from './app-card-utils' import type { ConfigParams } from './settings' import type { AppDetailResponse } from '@/models/app' import type { AppSSO } from '@/types/app' +import { Switch } from '@langgenius/dify-ui/switch' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import AppBasic from '@/app/components/app-sidebar/basic' import { useStore as useAppStore } from '@/app/components/app/store' -import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' import Indicator from '@/app/components/header/indicator' @@ -27,11 +28,16 @@ import { AppCardOperations, AppCardUrlSection, createAppCardOperations, + WorkflowLaunchDialog, } from './app-card-sections' import { + buildWorkflowLaunchUrl, + createWorkflowLaunchInitialValues, getAppCardDisplayState, getAppCardOperationKeys, + getAppHiddenLaunchVariables, isAppAccessConfigured, + isWorkflowLaunchInputSupported, } from './app-card-utils' export type IAppCardProps = { @@ -62,7 +68,8 @@ function AppCard({ const router = useRouter() const pathname = usePathname() const { isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() - const { data: currentWorkflow } = useAppWorkflow(appInfo.mode === AppModeEnum.WORKFLOW ? appInfo.id : '') + const shouldFetchWorkflow = appInfo.mode === AppModeEnum.WORKFLOW || appInfo.mode === AppModeEnum.ADVANCED_CHAT + const { data: currentWorkflow } = useAppWorkflow(shouldFetchWorkflow ? appInfo.id : '') const docLink = useDocLink() const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) @@ -72,6 +79,8 @@ function AppCard({ const [genLoading, setGenLoading] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [showAccessControl, setShowAccessControl] = useState(false) + const [showWorkflowLaunchDialog, setShowWorkflowLaunchDialog] = useState(false) + const [workflowLaunchValues, setWorkflowLaunchValues] = useState>({}) const { t } = useTranslation() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const { data: appAccessSubjects } = useAppWhiteListSubjects( @@ -97,6 +106,25 @@ function AppCard({ () => isAppAccessConfigured(appDetail, appAccessSubjects), [appAccessSubjects, appDetail], ) + const hiddenLaunchVariables = useMemo( + () => getAppHiddenLaunchVariables({ + appInfo, + currentWorkflow, + }) || [], + [appInfo, currentWorkflow], + ) + const supportedWorkflowLaunchVariables = useMemo( + () => hiddenLaunchVariables.filter(isWorkflowLaunchInputSupported), + [hiddenLaunchVariables], + ) + const unsupportedWorkflowLaunchVariables = useMemo( + () => hiddenLaunchVariables.filter(variable => !isWorkflowLaunchInputSupported(variable)), + [hiddenLaunchVariables], + ) + const initialWorkflowLaunchValues = useMemo( + () => createWorkflowLaunchInitialValues(supportedWorkflowLaunchVariables), + [supportedWorkflowLaunchVariables], + ) const onGenCode = async () => { if (!onGenerateCode) @@ -138,6 +166,31 @@ function AppCard({ window.open(cardState.accessibleUrl, '_blank') }, [cardState.accessibleUrl]) + const handleOpenWorkflowLaunchDialog = useCallback(() => { + setWorkflowLaunchValues(initialWorkflowLaunchValues) + setShowWorkflowLaunchDialog(true) + }, [initialWorkflowLaunchValues]) + + const handleWorkflowLaunchValueChange = useCallback((variable: string, value: WorkflowLaunchInputValue) => { + setWorkflowLaunchValues(prev => ({ + ...prev, + [variable]: value, + })) + }, []) + + const handleWorkflowLaunchConfirm = useCallback(async (event: React.FormEvent) => { + event.preventDefault() + + const targetUrl = await buildWorkflowLaunchUrl({ + accessibleUrl: cardState.accessibleUrl, + variables: supportedWorkflowLaunchVariables, + values: workflowLaunchValues, + }) + + window.open(targetUrl, '_blank') + setShowWorkflowLaunchDialog(false) + }, [cardState.accessibleUrl, supportedWorkflowLaunchVariables, workflowLaunchValues]) + const handleOpenCustomize = useCallback(() => { setShowCustomizeModal(true) }, []) @@ -279,7 +332,17 @@ function AppCard({ {!cardState.isMinimalState && (
{!isApp && } - + 0 + ? { + label: t('operation.config', { ns: 'common' }), + disabled: triggerModeDisabled || !cardState.runningStatus, + onClick: handleOpenWorkflowLaunchDialog, + } + : undefined} + />
)} @@ -298,6 +361,17 @@ function AppCard({ onCloseAccessControl={() => setShowAccessControl(false)} onSaveSiteConfig={onSaveSiteConfig} onConfirmAccessControl={handleAccessControlUpdate} + hiddenInputs={hiddenLaunchVariables} + /> + ) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index b1d7066779..e4b8aabf69 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline' +import { Button } from '@langgenius/dify-ui/button' import * as React from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' import Tag from '@/app/components/base/tag' -import { Button } from '@/app/components/base/ui/button' import { useDocLink } from '@/context/i18n' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/overview/embedded/__tests__/index.spec.tsx b/web/app/components/app/overview/embedded/__tests__/index.spec.tsx index 0a843c26fd..ac4574b5f4 100644 --- a/web/app/components/app/overview/embedded/__tests__/index.spec.tsx +++ b/web/app/components/app/overview/embedded/__tests__/index.spec.tsx @@ -1,10 +1,11 @@ import type { SiteInfo } from '@/models/share' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import copy from 'copy-to-clipboard' import * as React from 'react' - import { act } from 'react' -import { afterAll, afterEach, describe, expect, it, vi } from 'vitest' + +import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { InputVarType } from '@/app/components/workflow/types' import Embedded from '../index' vi.mock('../style.module.css', () => ({ @@ -46,6 +47,7 @@ vi.mock('@/context/app-context', () => ({ })) const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) const mockedCopy = vi.mocked(copy) +const originalCompressionStream = globalThis.CompressionStream const siteInfo: SiteInfo = { title: 'test site', @@ -70,6 +72,22 @@ const getCopyButton = () => { } describe('Embedded', () => { + beforeAll(() => { + class MockCompressionStream { + readable: ReadableStream + writable: WritableStream + + constructor() { + const transformStream = new TransformStream() + this.readable = transformStream.readable + this.writable = transformStream.writable + } + } + + // @ts-expect-error test polyfill + globalThis.CompressionStream = MockCompressionStream + }) + afterEach(() => { vi.clearAllMocks() mockWindowOpen.mockClear() @@ -77,6 +95,7 @@ describe('Embedded', () => { afterAll(() => { mockWindowOpen.mockRestore() + globalThis.CompressionStream = originalCompressionStream }) it('builds theme and copies iframe snippet', async () => { @@ -84,14 +103,20 @@ describe('Embedded', () => { render() }) + await waitFor(() => { + expect(screen.getByText((content, node) => node?.tagName.toLowerCase() === 'pre' && content.includes('/chatbot/token'))).toBeInTheDocument() + }) + const actionButton = getCopyButton() const innerDiv = actionButton.querySelector('div') - act(() => { + await act(async () => { fireEvent.click(innerDiv ?? actionButton) }) expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted) - expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token')) + await waitFor(() => { + expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token')) + }) }) it('opens chrome plugin store link when chrome option selected', async () => { @@ -116,4 +141,53 @@ describe('Embedded', () => { 'noopener,noreferrer', ) }) + + it('keeps hidden inputs collapsed by default and updates iframe and script content when values change', async () => { + render( + , + ) + + expect(screen.queryByLabelText('Secret')).not.toBeInTheDocument() + + await act(async () => { + fireEvent.click(screen.getByText('appOverview.overview.appInfo.embedded.hiddenInputs.title').closest('button')!) + }) + + await waitFor(() => { + expect(screen.getByLabelText('Secret')).toBeInTheDocument() + }) + + await act(async () => { + fireEvent.change(screen.getByLabelText('Secret'), { + target: { value: 'top-secret' }, + }) + }) + + expect(document.querySelector('pre')?.textContent ?? '').toContain('/chatbot/token') + + await waitFor(() => { + const codeBlock = document.querySelector('pre') + expect(codeBlock?.textContent ?? '').toContain('/chatbot/token?secret=dG9wLXNlY3JldA%3D%3D') + }) + + const optionButtons = document.body.querySelectorAll('[class*="option"]') + act(() => { + fireEvent.click(optionButtons[1]) + }) + + await waitFor(() => { + const codeBlock = document.querySelector('pre') + expect(codeBlock?.textContent ?? '').toContain('secret: "top-secret"') + }) + }) }) diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 029566587e..f19f3aba80 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -1,90 +1,48 @@ +import type { MutableRefObject } from 'react' +import type { WorkflowHiddenStartVariable, WorkflowLaunchInputValue } from '../app-card-utils' import type { SiteInfo } from '@/models/share' import { cn } from '@langgenius/dify-ui/cn' import { + RiArrowDownSLine, + RiArrowRightSLine, RiClipboardFill, RiClipboardLine, } from '@remixicon/react' import copy from 'copy-to-clipboard' -import * as React from 'react' -import { useEffect, useState } from 'react' +import { Suspense, use, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context' import Modal from '@/app/components/base/modal' import Tooltip from '@/app/components/base/tooltip' -import { IS_CE_EDITION } from '@/config' +import { InputVarType } from '@/app/components/workflow/types' import { useAppContext } from '@/context/app-context' import { basePath } from '@/utils/var' +import { + compressAndEncodeBase64, + createWorkflowLaunchInitialValues, + getChromePluginContent, + getEmbeddedIframeSnippet, + getEmbeddedScriptSnippet, + isWorkflowLaunchInputSupported, +} from '../app-card-utils' +import WorkflowHiddenInputFields from '../workflow-hidden-input-fields' import style from './style.module.css' type Props = { siteInfo?: SiteInfo isShow: boolean onClose: () => void - accessToken: string - appBaseUrl: string + accessToken?: string + appBaseUrl?: string + hiddenInputs?: WorkflowHiddenStartVariable[] className?: string } -const OPTION_MAP = { - iframe: { - getContent: (url: string, token: string) => - ``, - }, - scripts: { - getContent: (url: string, token: string, primaryColor: string, isTestEnv?: boolean) => - ` - -`, - }, - chromePlugin: { - getContent: (url: string, token: string) => `ChatBot URL: ${url}${basePath}/chatbot/${token}`, - }, -} +const OPTION_KEYS = ['iframe', 'scripts', 'chromePlugin'] as const const prefixEmbedded = 'overview.appInfo.embedded' -type Option = keyof typeof OPTION_MAP +type Option = typeof OPTION_KEYS[number] type OptionStatus = { iframe: boolean @@ -92,58 +50,192 @@ type OptionStatus = { chromePlugin: boolean } -const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, className }: Props) => { +const EMPTY_OPTION_STATUS: OptionStatus = { + iframe: false, + scripts: false, + chromePlugin: false, +} + +const getSerializedHiddenInputValue = ( + variable: WorkflowHiddenStartVariable, + values: Record, +) => { + const rawValue = values[variable.variable] + if (variable.type === InputVarType.checkbox) + return String(Boolean(rawValue)) + + return String(rawValue ?? '') +} + +const buildEmbeddedIframeUrl = async ({ + appBaseUrl, + accessToken, + variables, + values, +}: { + appBaseUrl: string + accessToken: string + variables: WorkflowHiddenStartVariable[] + values: Record +}) => { + const iframeUrl = new URL(`${appBaseUrl}${basePath}/chatbot/${accessToken}`, window.location.origin) + + await Promise.all(variables.map(async (variable) => { + iframeUrl.searchParams.set(variable.variable, await compressAndEncodeBase64(getSerializedHiddenInputValue(variable, values))) + })) + + return iframeUrl.toString() +} + +const AsyncEmbeddedOptionContent = ({ + option, + iframeUrlPromise, + latestResolvedIframeUrlRef, +}: { + option: Option + iframeUrlPromise: Promise + latestResolvedIframeUrlRef: MutableRefObject +}) => { + const iframeUrl = use(iframeUrlPromise) + latestResolvedIframeUrlRef.current = iframeUrl + + if (option === 'chromePlugin') + return getChromePluginContent(iframeUrl) + + return getEmbeddedIframeSnippet(iframeUrl) +} + +const EmbeddedContent = ({ + siteInfo, + appBaseUrl, + accessToken, + hiddenInputs, +}: Required> & Pick) => { const { t } = useTranslation() + const supportedHiddenInputs = useMemo( + () => (hiddenInputs ?? []).filter(isWorkflowLaunchInputSupported), + [hiddenInputs], + ) + const initialHiddenInputValues = useMemo( + () => createWorkflowLaunchInitialValues(supportedHiddenInputs), + [supportedHiddenInputs], + ) const [option, setOption] = useState