diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 82d9c21cbb..5884830569 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -20,11 +20,11 @@ ```typescript // ❌ WRONG: Don't mock base components vi.mock('@/app/components/base/loading', () => () =>
Loading
) -vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => ) +vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => ) // ✅ CORRECT: Import and use real base components import Loading from '@/app/components/base/loading' -import { Button } from '@/app/components/base/ui/button' +import { Button } from '@langgenius/dify-ui/button' // They will render normally in tests ``` diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index eefb1ebbb9..c55b013dbe 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
', + }); + } diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f3ab4c62c7..2a5cf19645 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/api/commands/account.py b/api/commands/account.py index 6a2a2e0428..761323a73d 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,6 +2,7 @@ import base64 import secrets import click +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm): # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = db.session.merge(account) - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + session.commit() AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account = db.session.merge(account) - account.email = normalized_new_email - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.email = normalized_new_email + session.commit() click.echo(click.style("Email updated successfully.", fg="green")) diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 0000000000..b0fbe0075c --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 23351beed9..7302a4edf5 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -126,8 +126,6 @@ from .snippets import snippet_workflow, snippet_workflow_draft_variable from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import snippet controllers -from .snippets import snippet_workflow, snippet_workflow_draft_variable - # Import tag controllers from .tag import tags @@ -215,12 +213,12 @@ __all__ = [ "setup", "site", "snippet_workflow", - "snippet_workflow_draft_variable", - "snippets", - "socketio_workflow", "snippet_workflow", "snippet_workflow_draft_variable", + "snippet_workflow_draft_variable", "snippets", + "snippets", + "socketio_workflow", "spec", "statistic", "tags", diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index cead33d14f..9c8b095b9f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): return value try: diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f6319573e0..e32ba5f66c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable): def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type - return value_type.exposed_type().value + return str(value_type.exposed_type()) class FullContentDict(TypedDict): @@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict result: FullContentDict = { "size_bytes": variable_file.size, - "value_type": variable_file.value_type.exposed_type().value, + "value_type": str(variable_file.value_type.exposed_type()), "length": variable_file.length, "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True), } @@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.exposed_type().value, + "value_type": str(v.value_type.exposed_type()), "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 50851aea08..47e1889c95 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel): def normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 06c746990d..c22102c2ba 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,20 +4,6 @@ import uuid from decimal import Decimal from typing import Union, cast -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity @@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index fdffde85d0..b31fde5fb2 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -300,7 +300,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # update prompt tool for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + tool_instance = tool_instances.get(prompt_tool.name) + if tool_instance: + self.update_prompt_message_tool(tool_instance, prompt_tool) iteration_step += 1 diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index b7dd55632e..5df3df2b3e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,14 +1,13 @@ from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a20d3f3c38..cae0eee0df 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,9 +1,6 @@ import logging from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index dfe6133cb6..e2e07ebaff 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 68e5e5f0c8..3a6f9d575a 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.protocols import WorkflowFileRuntimeProtocol from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index ada065a943..602940af6e 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -350,7 +350,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index a5297fa33a..4cb1905688 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -352,11 +352,11 @@ class DatasourceManager: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index d07f6f913a..38b87e2cd1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -8,16 +8,6 @@ from collections.abc import Iterator, Sequence from json import JSONDecodeError from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -34,6 +24,16 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials( - self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None - ): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. :param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel): model: str, credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. @@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None @@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce380..38864a1830 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a1e782a094..f169f247cf 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,14 +2,13 @@ import logging import secrets from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e38592bb7b..91e92712b7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client from core.tools.errors import ToolSSRFError +from graphon.http.response import HttpResponse logger = logging.getLogger(__name__) @@ -267,4 +268,47 @@ class SSRFProxy: return patch(url=url, max_retries=max_retries, **kwargs) +def _to_graphon_http_response(response: httpx.Response) -> HttpResponse: + """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper.""" + return HttpResponse( + status_code=response.status_code, + headers=dict(response.headers), + content=response.content, + url=str(response.url) if response.url else None, + reason_phrase=response.reason_phrase, + fallback_text=response.text, + ) + + +class GraphonSSRFProxy: + """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + ssrf_proxy = SSRFProxy() +graphon_ssrf_proxy = GraphonSSRFProxy() diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 36beb55d7f..86d0e3baaa 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,20 +1,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config from core.entities import PluginCredentialType @@ -25,9 +11,24 @@ from core.errors.error import ProviderTokenNotInitError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -169,7 +170,7 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -194,7 +195,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -214,7 +215,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, texts=texts, @@ -236,7 +237,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, @@ -253,7 +254,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, texts=texts, @@ -278,7 +279,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, query=query, @@ -306,7 +307,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, + self.model_type_instance.invoke_multimodal_rerank, model=self.model_name, credentials=self.credentials, query=query, @@ -325,7 +326,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, text=text, @@ -341,7 +342,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, file=file, @@ -358,14 +359,14 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, content_text=content_text, voice=voice, ) - def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9..d78ce90aa1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. - """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. - """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. - """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index cd63951537..e7ba6e502b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", "host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index e3fba4ef3a..4e66d58b5e 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime): if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") file_name = ( - provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us ) elif icon_type.lower() == "icon_small_dark": if not provider_schema.icon_small_dark: raise ValueError(f"Provider {provider} does not have small dark icon.") file_name = ( - provider_schema.icon_small_dark.zh_Hans + provider_schema.icon_small_dark.zh_hans if lang.lower() == "zh_hans" - else provider_schema.icon_small_dark.en_US + else provider_schema.icon_small_dark.en_us ) else: raise ValueError(f"Unsupported icon type: {icon_type}.") diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 9be70199b7..126888f739 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -5,7 +5,7 @@ from graphon.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 9f1c73ec88..a9995778f7 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,8 +4,6 @@ import pickle from typing import Any, cast import numpy as np -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -15,6 +13,8 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 052fca930d..0330a43b28 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -3,6 +3,7 @@ Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). """ +import inspect import logging import mimetypes import os @@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor): file_path: Path to the file to load. """ + _closed: bool + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" + self._closed = False self.file_path = file_path self.tenant_id = tenant_id self.user_id = user_id @@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") + def close(self) -> None: + """Best-effort cleanup for downloaded temporary files.""" + if getattr(self, "_closed", False): + return + + self._closed = True + temp_file = getattr(self, "temp_file", None) + if temp_file is None: + return + + try: + close_result = temp_file.close() + if inspect.isawaitable(close_result): + close_awaitable = getattr(close_result, "close", None) + if callable(close_awaitable): + close_awaitable() + except Exception: + logger.debug("Failed to cleanup downloaded word temp file", exc_info=True) + def __del__(self): - if hasattr(self, "temp_file"): - self.temp_file.close() + self.close() def extract(self) -> list[Document]: """Load given path as single page.""" diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a487c49053..ace0b791e8 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): try: # Create File object directly (similar to DatasetRetrieval) file_obj = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8ebc840b99..5631b3a921 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,11 +9,6 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select, update from sqlalchemy.orm import sessionmaker @@ -69,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile @@ -517,11 +517,11 @@ class DatasetRetrieval: if attachments_with_bindings: for _, upload_file in attachments_with_bindings: attachment_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 3383c7f3bd..66b375dad1 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -7,10 +7,9 @@ import re from collections.abc import Collection from typing import Any, Literal -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter +from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f..740d727e26 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -8,7 +8,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index d8674b3af9..ada3ed7f43 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -28,7 +28,7 @@ class ToolFileManager: def _build_graph_file_reference(tool_file: ToolFile) -> File: extension = guess_extension(tool_file.mimetype) or ".bin" return File( - type=get_file_type_by_mime_type(tool_file.mimetype), + file_type=get_file_type_by_mime_type(tool_file.mimetype), transfer_method=FileTransferMethod.TOOL_FILE, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index be13d40f3e..31964b6702 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1083,7 +1083,12 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.value + if not isinstance(variable_selector, list) or not all( + isinstance(selector_part, str) for selector_part in variable_selector + ): + raise ToolParameterError("Variable tool input must be a variable selector") + variable = variable_pool.get(variable_selector) if variable is None: raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8d6f83dc07..e534a9415a 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -18,7 +18,7 @@ from graphon.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder from core.model_manager import ModelManager diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 7c4f8ee03a..e510cfcead 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -357,7 +357,10 @@ class WorkflowTool(Tool): def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) - transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) + transfer_method_value = file_dict.get("transfer_method") + if not isinstance(transfer_method_value, str): + raise ValueError("Workflow file mapping is missing a valid transfer_method") + transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.TOOL_FILE: file_dict["tool_file_id"] = file_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py similarity index 74% rename from api/core/workflow/human_input_compat.py rename to api/core/workflow/human_input_adapter.py index 75a0a0c202..4b765e6aea 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_adapter.py @@ -1,8 +1,8 @@ -"""Workflow-layer adapters for legacy human-input payload keys. +"""Workflow-to-Graphon adapters for persisted node payloads. -Stored workflow graphs and editor payloads may still use Dify-specific human -input recipient keys. Normalize them here before handing configs to -`graphon` so graph-owned models only see graph-neutral field names. +Stored workflow graphs and editor payloads still contain a small set of +Dify-owned field spellings and value shapes. Adapt them here before handing the +payload to Graphon so Graphon-owned models only see current contracts. """ from __future__ import annotations @@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None: return None -def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") @@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: - normalized = normalize_human_input_node_data_for_graph(node_data) + normalized = adapt_human_input_node_data_for_graph(node_data) raw_delivery_methods = normalized.get("delivery_methods") if not isinstance(raw_delivery_methods, list): return [] @@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b return False -def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") - if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: - return normalized - return normalize_human_input_node_data_for_graph(normalized) + node_type = normalized.get("type") + if node_type == BuiltinNodeTypes.HUMAN_INPUT: + return adapt_human_input_node_data_for_graph(normalized) + if node_type == BuiltinNodeTypes.TOOL: + return _adapt_tool_node_data_for_graph(normalized) + return normalized -def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_config) if normalized is None: raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") @@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) if data_mapping is None: return normalized - normalized["data"] = normalize_node_data_for_graph(data_mapping) + normalized["data"] = adapt_node_data_for_graph(data_mapping) return normalized +def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(node_data) + + raw_tool_configurations = normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -291,9 +349,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 351da3444f..de4eae1b22 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType from graphon.file.file_manager import file_manager from graphon.graph.graph import NodeFactory from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.base.node import Node from graphon.nodes.code.code_node import WorkflowCodeExecutor from graphon.nodes.code.entities import CodeLanguage @@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. + # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. + resolved_node_data = self._validate_resolved_node_data(node_class, node_data) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. """ - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e632e56f0..b8725853c4 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlalchemy import select from sqlalchemy.orm import Session @@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import ( ) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.llm.runtime_protocols import ( PreparedLLMProtocol, @@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + def invoke_llm_with_structured_output( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0..68a24e86b1 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node @@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d9247b2593..f3006c4242 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,7 +1,12 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -12,13 +17,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment - from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError @@ -37,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index bb72fe3881..9c1b7ab2c4 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,15 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template - from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -33,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 13624b27b3..6109ced8e2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -9,7 +9,6 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -51,6 +50,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -60,13 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -283,18 +295,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 288d37d265..1d2ad4d445 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -10,8 +10,8 @@ from typing import Any from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + file_id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + file_id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -220,9 +222,9 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=filename, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + file_id=mapping.get("id"), + filename=tool_file.name, + file_type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + file_id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + file_type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4..d518114777 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index cf4a71d545..e4219ba1ee 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index e103abee43..3af42eabd8 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.exposed_type().value, + "value_type": str(value.value_type.exposed_type()), "description": value.description, } if isinstance(value, dict): diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 9b53918f24..934aacb45b 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -6,8 +6,8 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.db.session_factory import session_factory from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) @@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def sync_data_source(self, binding_id: str) -> None: # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ) - if data_source_binding: - # get all authorized pages - pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) - new_source_info = self._build_source_info( - workspace_name=source_info["workspace_name"], - workspace_icon=source_info["workspace_icon"], - workspace_id=source_info["workspace_id"], - pages=pages, - ) - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - raise ValueError("Data source binding not found") + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: pages: list[NotionPageSummary] = [] diff --git a/api/models/comment.py b/api/models/comment.py index 7018c7e1f2..4104861d10 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional +import sqlalchemy as sa from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -36,24 +37,27 @@ class WorkflowComment(Base): __tablename__ = "workflow_comments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), Index("workflow_comments_app_idx", "tenant_id", "app_id"), Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - position_x: Mapped[float] = mapped_column(db.Float) - position_y: Mapped[float] = mapped_column(db.Float) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime) + resolved: Mapped[bool] = mapped_column( + sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) resolved_by: Mapped[str | None] = mapped_column(StringUUID) # Relationships @@ -143,23 +147,26 @@ class WorkflowCommentReply(Base): __tablename__ = "workflow_comment_replies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), Index("comment_replies_comment_idx", "comment_id"), Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + comment: Mapped["WorkflowComment"] = relationship( + "WorkflowComment", back_populates="replies") @property def created_by_account(self): @@ -187,24 +194,26 @@ class WorkflowCommentMention(Base): __tablename__ = "workflow_comment_mentions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), Index("comment_mentions_comment_idx", "comment_id"), Index("comment_mentions_reply_idx", "reply_id"), Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column( + StringUUID, server_default=sa.text("uuidv7()")) comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) reply_id: Mapped[str | None] = mapped_column( - StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True ) mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - # Relationships - comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") - reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + comment: Mapped["WorkflowComment"] = relationship( + "WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"] + ] = relationship("WorkflowCommentReply") @property def mentioned_user_account(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index 79c5d62f6a..7447d3efcb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index a2dc8f6157..77dcbd13d4 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -5,7 +5,8 @@ from functools import lru_cache from typing import Any from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. + """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -76,12 +195,7 @@ def build_file_from_stored_mapping( pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index db2a647dc8..bb4c24380f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -37,7 +37,7 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, @@ -65,7 +65,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -339,7 +342,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -1737,7 +1740,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1747,7 +1750,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/providers/README.md b/api/providers/README.md index a00ec8bc52..5d5e6db9af 100644 --- a/api/providers/README.md +++ b/api/providers/README.md @@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). +## Excluding Providers + +In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 0000000000..a7ffa5ed26 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. + +## Architecture + +| Layer | Location | Role | +|--------|----------|------| +| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads | +| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class | +| Your package | `api/providers/trace/trace-/` | Pydantic config + subclass of `BaseTraceInstance` | + +At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype. + +## What you implement + +### 1. Config model (`BaseTracingConfig`) + +Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate. + +Fields fall into two groups used by the manager: + +- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords). +- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints). + +List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct. + +### 2. Trace instance (`BaseTraceInstance`) + +Subclass `BaseTraceInstance` and implement: + +```python +def trace(self, trace_info: BaseTraceInfo) -> None: + ... +``` + +Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including: + +- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace` +- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo` +- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo` + +You may ignore categories your backend does not support; existing providers often no-op unhandled types. + +Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context. + +### 3. Register in the API core + +Upstream changes are required so Dify knows your provider exists: + +1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`). +2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning: + - `config_class`: your Pydantic config type + - `secret_keys` / `other_keys`: lists of field names as above + - `trace_instance`: your `BaseTraceInstance` subclass + Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`. + +If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app. + +## Package layout + +Each provider is a normal uv workspace member, for example: + +- `api/providers/trace/trace-/pyproject.toml` — project name `dify-trace-`, dependencies on vendor SDKs +- `api/providers/trace/trace-/src/dify_trace_/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`. + +Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`. + +## Wiring into the `api` workspace + +In `api/pyproject.toml`: + +1. **`[tool.uv.sources]`** — `dify-trace- = { workspace = true }` +2. **`[dependency-groups]`** — add `trace- = ["dify-trace-"]` and include `dify-trace-` in `trace-all` if it should ship with the default bundle + +After changing metadata, run **`uv sync`** from `api/`. diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml new file mode 100644 index 0000000000..bcef7e9fb1 --- /dev/null +++ b/api/providers/trace/trace-aliyun/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-aliyun" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Aliyun)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 70aaf2a07b..eb44d437f9 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -6,7 +6,20 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -14,8 +27,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -34,7 +47,7 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -46,19 +59,6 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 0000000000..e0133e6cc9 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f..00aab6bf89 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 97% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index aa35ac74c2..d446f1b5c3 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -6,7 +6,8 @@ from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -15,7 +16,6 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db from models import EndUser @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d4036..286dda419c 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 @@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. @@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -268,7 +267,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +289,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c..38d33dd21b 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 3961555b9a..9cab40748f 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 99% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index c2324fdec4..c1b11c9186 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -4,12 +4,11 @@ from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 95% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index e4d8f2d5ea..a9e7b80c2a 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,9 +1,7 @@ import json from unittest.mock import MagicMock -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionStatus @@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +112,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..1b24ee7421 --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig() + + with pytest.raises(ValidationError): + AliyunConfig(license_key="test_license") + + with pytest.raises(ValidationError): + AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + ) + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_license_key_required(self): + """Test that license_key is required and cannot be empty""" + with pytest.raises(ValidationError): + AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml new file mode 100644 index 0000000000..9e756944c9 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +dependencies = [ + "arize-phoenix-otel~=0.15.0", +] +description = "Dify ops tracing provider (Arize / Phoenix)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py similarity index 100% rename from api/core/ops/langfuse_trace/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py similarity index 99% rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index dd5edde630..037023e771 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -26,7 +26,6 @@ from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -40,6 +39,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py new file mode 100644 index 0000000000..6eac5b30d2 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py @@ -0,0 +1,45 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class ArizeConfig(BaseTracingConfig): + """ + Model class for Arize tracing config. + """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. + """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd7..b0691a87ea 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +80,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -142,8 +142,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +172,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +228,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_success(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() @@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance): assert trace_instance.tracer.start_span.call_count >= 1 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py index 4b925390d9..a01c63ae61 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py @@ -1,6 +1,6 @@ +from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from openinference.semconv.trace import OpenInferenceSpanKindValues -from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..11e951c3b1 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py @@ -0,0 +1,88 @@ +import pytest +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from pydantic import ValidationError + + +class TestArizeConfig: + """Test cases for ArizeConfig""" + + def test_valid_config(self): + """Test valid Arize configuration""" + config = ArizeConfig( + api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" + ) + assert config.api_key == "test_key" + assert config.space_id == "test_space" + assert config.project == "test_project" + assert config.endpoint == "https://custom.arize.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = ArizeConfig() + assert config.api_key is None + assert config.space_id is None + assert config.project is None + assert config.endpoint == "https://otlp.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = ArizeConfig(project="") + assert config.project == "default" + + def test_project_validation_none(self): + """Test project validation with None value""" + config = ArizeConfig(project=None) + assert config.project == "default" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = ArizeConfig(endpoint="") + assert config.endpoint == "https://otlp.arize.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation normalizes URL by removing path""" + config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") + assert config.endpoint == "https://custom.arize.com" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="ftp://invalid.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="invalid.com") + + +class TestPhoenixConfig: + """Test cases for PhoenixConfig""" + + def test_valid_config(self): + """Test valid Phoenix configuration""" + config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.phoenix.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = PhoenixConfig() + assert config.api_key is None + assert config.project is None + assert config.endpoint == "https://app.phoenix.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = PhoenixConfig(project="") + assert config.project == "default" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml new file mode 100644 index 0000000000..27d2273a69 --- /dev/null +++ b/api/providers/trace/trace-langfuse/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langfuse" +version = "0.0.1" +dependencies = [ + "langfuse>=4.2.0,<5.0.0", +] +description = "Dify ops tracing provider (Langfuse)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py new file mode 100644 index 0000000000..90d1a2846b --- /dev/null +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py @@ -0,0 +1,19 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class LangfuseConfig(BaseTracingConfig): + """ + Model class for Langfuse tracing config. + """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 99% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index 7eacc2be46..68881378a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index a0bcc92795..952f10c34f 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,8 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..103d888eef --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig() + + with pytest.raises(ValidationError): + LangfuseConfig(public_key="public") + + with pytest.raises(ValidationError): + LangfuseConfig(secret_key="secret") + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py similarity index 92% rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index 017ac8c891..0340ffb669 100644 --- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -4,14 +4,15 @@ from datetime import datetime, timedelta from types import SimpleNamespace from unittest.mock import MagicMock, patch -from core.ops.entities.config_entity import LangfuseConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace + from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes def _create_trace_instance() -> LangFuseDataTrace: - with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True): + with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True): return LangFuseDataTrace( LangfuseConfig( public_key="public-key", @@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime: patch.object(trace, "add_span"), patch.object(trace, "add_generation") as add_generation, patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), - patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()), + patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()), patch( - "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=repository, ), ): diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml new file mode 100644 index 0000000000..8131952b28 --- /dev/null +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langsmith" +version = "0.0.1" +dependencies = [ + "langsmith~=0.7.30", +] +description = "Dify ops tracing provider (LangSmith)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py similarity index 100% rename from api/core/ops/opik_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py new file mode 100644 index 0000000000..498b8c5e7e --- /dev/null +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py @@ -0,0 +1,20 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class LangSmithConfig(BaseTracingConfig): + """ + Model class for Langsmith tracing config. + """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 490c64af84..504673d6ef 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -10,7 +10,6 @@ from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -22,13 +21,14 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index 34c64c54a1..45e5894e4a 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,8 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..37efaf69cf --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig() + + with pytest.raises(ValidationError): + LangSmithConfig(api_key="key") + + with pytest.raises(ValidationError): + LangSmithConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 0000000000..fad6002944 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 0000000000..84914165e3 --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. + """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. + """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 99% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index c070a937be..37c9ecda25 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -12,7 +12,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -25,6 +24,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 98% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index afc5726ede..20211456e3 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" from __future__ import annotations @@ -9,8 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml new file mode 100644 index 0000000000..874997168e --- /dev/null +++ b/api/providers/trace/trace-opik/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-opik" +version = "0.0.1" +dependencies = [ + "opik~=1.11.2", +] +description = "Dify ops tracing provider (Opik)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py new file mode 100644 index 0000000000..c16ff1d903 --- /dev/null +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py @@ -0,0 +1,25 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 99% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index e0c7b9bfe5..fd2d0ca097 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -11,7 +11,6 @@ from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index c02ac413f2..eefed3c78c 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,8 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_opik.config import OpikConfig +from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -37,7 +37,7 @@ def opik_config(): @pytest.fixture def trace_instance(opik_config, monkeypatch): mock_client = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client) instance = OpikDataTrace(opik_config) return instance @@ -67,7 +67,7 @@ def test_prepare_opik_uuid(): def test_init(opik_config, monkeypatch): mock_opik = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik) monkeypatch.setenv("FILES_URL", "http://test.url") instance = OpikDataTrace(opik_config) @@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) node_llm = MagicMock() node_llm.id = LLM_NODE_ID @@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..5a54b70bba --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def test_url_validation_invalid_scheme(self): + """Test URL validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + OpikConfig(url="ftp://custom.comet.com/opik/api/") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index ad9d0846be..fba290f5b8 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -14,8 +14,9 @@ import uuid from datetime import datetime from unittest.mock import MagicMock, patch +from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo -from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid # A stable UUID4 used as the workflow_run_id throughout all tests. _WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" @@ -56,8 +57,8 @@ def _make_workflow_trace_info( def _make_opik_trace_instance() -> OpikDataTrace: """Construct an OpikDataTrace with the Opik SDK client mocked out.""" - with patch("core.ops.opik_trace.opik_trace.Opik"): - from core.ops.entities.config_entity import OpikConfig + with patch("dify_trace_opik.opik_trace.Opik"): + from dify_trace_opik.config import OpikConfig config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") instance = OpikDataTrace(config) @@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml new file mode 100644 index 0000000000..eab06fc708 --- /dev/null +++ b/api/providers/trace/trace-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-tencent" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Tencent APM)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py similarity index 100% rename from api/core/ops/tencent_trace/client.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py new file mode 100644 index 0000000000..398e6c55a8 --- /dev/null +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py @@ -0,0 +1,30 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig + + +class TencentConfig(BaseTracingConfig): + """ + Tencent APM tracing config + """ + + token: str + endpoint: str + service_name: str + + @field_validator("token") + @classmethod + def token_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("Token cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") + + @field_validator("service_name") + @classmethod + def service_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/entities/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py similarity index 100% rename from api/core/ops/tencent_trace/entities/semconv.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py similarity index 100% rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py similarity index 98% rename from api/core/ops/tencent_trace/span_builder.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py index f79095d966..763a85ffd7 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py @@ -6,8 +6,6 @@ import json import logging from datetime import datetime -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -16,7 +14,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_tencent.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, GEN_AI_IS_ENTRY, @@ -40,9 +39,10 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.utils import TencentTraceUtils -from core.rag.models.document import Document +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.utils import TencentTraceUtils +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py similarity index 94% rename from api/core/ops/tencent_trace/tencent_trace.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index 84f54d8a5a..3b9e24573a 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -1,7 +1,6 @@ -""" -Tencent APM tracing implementation with separated concerns -""" +"""Tencent APM tracing with idempotent client cleanup.""" +import inspect import logging from graphon.entities.workflow_node_execution import ( @@ -12,7 +11,6 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -23,11 +21,12 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.client import TencentTraceClient -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.span_builder import TencentSpanBuilder -from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from dify_trace_tencent.client import TencentTraceClient +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.span_builder import TencentSpanBuilder +from dify_trace_tencent.utils import TencentTraceUtils from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom @@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance): """ Tencent APM trace implementation with single responsibility principle. Acts as a coordinator that delegates specific tasks to specialized classes. + + The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen + explicitly in tests or implicitly during garbage collection, so shutdown + must be safe to call multiple times. """ + trace_client: TencentTraceClient + _closed: bool + def __init__(self, tencent_config: TencentConfig): super().__init__(tencent_config) + self._closed = False self.trace_client = TencentTraceClient( service_name=tencent_config.service_name, endpoint=tencent_config.endpoint, @@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance): except Exception: logger.debug("[Tencent APM] Failed to record message trace duration") - def __del__(self): - """Ensure proper cleanup on garbage collection.""" + def close(self) -> None: + """Synchronously and idempotently shutdown the underlying trace client.""" + if getattr(self, "_closed", False): + return + + self._closed = True + trace_client = getattr(self, "trace_client", None) + if trace_client is None: + return + try: - if hasattr(self, "trace_client"): - self.trace_client.shutdown() + shutdown_result = trace_client.shutdown() + if inspect.isawaitable(shutdown_result): + close_awaitable = getattr(shutdown_result, "close", None) + if callable(close_awaitable): + close_awaitable() except Exception: logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") + + def __del__(self): + """Ensure best-effort cleanup on garbage collection without retrying shutdown.""" + self.close() diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py similarity index 100% rename from api/core/ops/tencent_trace/utils.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py similarity index 98% rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 870c18e53e..1e656e2462 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -8,13 +8,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_tencent import client as client_module +from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version +from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event from opentelemetry.trace import Status, StatusCode -from core.ops.tencent_trace import client as client_module -from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData - metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py similarity index 89% rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py index 6113e5c6c8..e850a801f3 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py @@ -1,15 +1,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch -from opentelemetry.trace import StatusCode - -from core.ops.entities.trace_entity import ( - DatasetRetrievalTraceInfo, - MessageTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.ops.tencent_trace.entities.semconv import ( +from dify_trace_tencent.entities.semconv import ( GEN_AI_IS_ENTRY, GEN_AI_IS_STREAMING_REQUEST, GEN_AI_MODEL_NAME, @@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from dify_trace_tencent.span_builder import TencentSpanBuilder +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio class TestTencentSpanBuilder: def test_get_time_nanoseconds(self): - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: mock_convert.return_value = 123456789 dt = datetime.now() result = TencentSpanBuilder._get_time_nanoseconds(dt) @@ -48,7 +48,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {"answer": "world"} trace_info.metadata = {"conversation_id": "conv_id"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -70,7 +70,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {} trace_info.metadata = {} # No conversation_id - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 1 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -98,7 +98,7 @@ class TestTencentSpanBuilder: } node_execution.outputs = {"text": "world"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -123,7 +123,7 @@ class TestTencentSpanBuilder: "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, } - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -142,7 +142,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {"conversation_id": "conv_id"} trace_info.is_streaming_request = True - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -162,7 +162,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {} trace_info.is_streaming_request = False - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -182,7 +182,7 @@ class TestTencentSpanBuilder: trace_info.tool_inputs = {"i": 2} trace_info.tool_outputs = "result" - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 101 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) @@ -204,7 +204,7 @@ class TestTencentSpanBuilder: ) trace_info.documents = [doc] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -222,7 +222,7 @@ class TestTencentSpanBuilder: trace_info.end_time = datetime.now() trace_info.documents = [] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -264,7 +264,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -286,7 +286,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -307,7 +307,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -329,7 +329,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -350,7 +350,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 505 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py similarity index 86% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 7afd0b824a..54524b09ca 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -1,9 +1,12 @@ +import gc import logging -from unittest.mock import MagicMock, patch +import warnings +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.tencent_trace import TencentDataTrace -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -13,7 +16,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.tencent_trace import TencentDataTrace from graphon.entities import WorkflowNodeExecution from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin @@ -28,19 +30,19 @@ def tencent_config(): @pytest.fixture def mock_trace_client(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock: yield mock @pytest.fixture def mock_span_builder(): - with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock: yield mock @pytest.fixture def mock_trace_utils(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock: yield mock @@ -198,9 +200,9 @@ class TestTencentDataTrace: trace_info.workflow_run_id = "run-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.workflow_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") @@ -230,9 +232,9 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.message_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") @@ -262,9 +264,9 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.tool_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") @@ -294,22 +296,22 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.dataset_retrieval_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") def test_suggested_question_trace(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") def test_suggested_question_trace_exception(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") @@ -342,7 +344,7 @@ class TestTencentDataTrace: with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) # The exception should be caught by the outer handler since convert_to_span_id is called first mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -351,7 +353,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -381,7 +383,7 @@ class TestTencentDataTrace: node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) assert result is None mock_log.assert_called_once() @@ -403,15 +405,13 @@ class TestTencentDataTrace: mock_executions = [MagicMock()] - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.side_effect = [app, account, tenant_join] - with patch( - "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" - ) as mock_repo: + with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo: mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) @@ -423,7 +423,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -432,14 +432,14 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() # Ensure init_app is mocked mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -449,8 +449,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() mock_db.engine = MagicMock() @@ -476,8 +476,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") @@ -519,7 +519,7 @@ class TestTencentDataTrace: node.process_data = None node.outputs = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_llm_metrics(node) # Should not crash @@ -557,7 +557,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_llm_metrics(trace_info) # Should not crash @@ -609,7 +609,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -631,16 +631,41 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) - def test_del(self, tencent_data_trace): + def test_close(self, tencent_data_trace): client = tencent_data_trace.trace_client - tencent_data_trace.__del__() + tencent_data_trace.close() client.shutdown.assert_called_once() - def test_del_exception(self, tencent_data_trace): + def test_close_is_idempotent(self, tencent_data_trace): + client = tencent_data_trace.trace_client + + tencent_data_trace.close() + tencent_data_trace.close() + + client.shutdown.assert_called_once() + + def test_close_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.__del__() + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.close() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + + def test_close_handles_async_shutdown_mock(self, tencent_data_trace): + shutdown = AsyncMock() + tencent_data_trace.trace_client.shutdown = shutdown + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + tencent_data_trace.close() + gc.collect() + + shutdown.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) + and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py similarity index 88% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py index ef28d18e20..63c6d680d7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py @@ -8,10 +8,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from dify_trace_tencent.utils import TencentTraceUtils from opentelemetry.trace import Link, TraceFlags -from core.ops.tencent_trace.utils import TencentTraceUtils - def test_convert_to_trace_id_with_valid_uuid() -> None: uuid_str = "12345678-1234-5678-1234-567812345678" @@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None: def test_convert_to_trace_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int uuid4_mock.assert_called_once() @@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: def test_convert_to_span_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") assert isinstance(span_id, int) uuid4_mock.assert_called_once() @@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: def test_generate_span_id_skips_invalid_span_id() -> None: with patch( - "core.ops.tencent_trace.utils.random.getrandbits", + "dify_trace_tencent.utils.random.getrandbits", side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], ) as bits_mock: assert TencentTraceUtils.generate_span_id() == 42 @@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) expected = int(fixed.timestamp() * 1e9) - with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + with patch("dify_trace_tencent.utils.datetime") as datetime_mock: datetime_mock.now.return_value = fixed assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected datetime_mock.now.assert_called_once() @@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] assert link.context.trace_id == fallback_uuid.int uuid4_mock.assert_called_once() diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml new file mode 100644 index 0000000000..ba449f2a93 --- /dev/null +++ b/api/providers/trace/trace-weave/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-weave" +version = "0.0.1" +dependencies = [ + "weave>=0.52.36", +] +description = "Dify ops tracing provider (Weave)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py new file mode 100644 index 0000000000..5942bd57fe --- /dev/null +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py @@ -0,0 +1,29 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class WeaveConfig(BaseTracingConfig): + """ + Model class for Weave tracing config. + """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 8d9ba4694d..fd9b965ba0 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -18,7 +18,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -30,8 +29,9 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..eeb1fe1d87 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig() + + with pytest.raises(ValidationError): + WeaveConfig(api_key="key") + + with pytest.raises(ValidationError): + WeaveConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 531c7de05f..6028d0c550 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,9 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 +191,14 @@ def _make_node(**overrides): @pytest.fixture def mock_wandb(): - with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + with patch("dify_trace_weave.weave_trace.wandb") as mock: mock.login.return_value = True yield mock @pytest.fixture def mock_weave(): - with patch("core.ops.weave_trace.weave_trace.weave") as mock: + with patch("dify_trace_weave.weave_trace.weave") as mock: client = MagicMock() client.entity = "my-entity" client.project = "my-project" @@ -307,7 +307,7 @@ class TestGetProjectUrl: monkeypatch.setattr(trace_instance, "entity", None) monkeypatch.setattr(trace_instance, "project_name", None) # Force an error by making string formatting fail - with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + with patch("dify_trace_weave.weave_trace.logger") as mock_logger: # Simulate exception via property original_entity = trace_instance.entity trace_instance.entity = None @@ -594,9 +594,9 @@ class TestWorkflowTrace: mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) return repo def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): @@ -703,8 +703,8 @@ class TestWorkflowTrace: def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): """Raises ValueError when app_id is missing from metadata.""" - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) trace_info = _make_workflow_trace_info( message_id=None, @@ -802,7 +802,7 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.get", + "dify_trace_weave.weave_trace.db.session.get", lambda model, pk: None, ) @@ -824,7 +824,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -846,7 +846,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = end_user - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -866,7 +866,7 @@ class TestMessageTrace: """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -884,7 +884,7 @@ class TestMessageTrace: """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -899,7 +899,7 @@ class TestMessageTrace: """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() diff --git a/api/pyproject.toml b/api/pyproject.toml index af8eb864b0..c349979a7c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -32,9 +32,6 @@ dependencies = [ "flask-restx>=1.3.2,<2.0.0", "google-cloud-aiplatform>=1.147.0,<2.0.0", "httpx[socks]>=0.28.1,<1.0.0", - "langfuse>=4.2.0,<5.0.0", - "langsmith>=0.7.31,<1.0.0", - "mlflow-skinny>=3.11.1,<4.0.0", "opentelemetry-distro>=0.62b0,<1.0.0", "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", @@ -44,15 +41,12 @@ dependencies = [ "opentelemetry-propagator-b3>=1.41.0,<2.0.0", "readabilipy>=0.3.0,<1.0.0", "resend>=2.27.0,<3.0.0", - "weave>=0.52.36,<1.0.0", # Emerging: newer and fast-moving, use compatible pins - "arize-phoenix-otel~=0.15.0", "fastopenapi[flask]~=0.7.0", - "graphon~=0.1.2", + "graphon~=0.2.2", "httpx-sse~=0.4.0", "json-repair~=0.59.2", - "opik~=1.11.2", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -61,8 +55,8 @@ dependencies = [ packages = [] [tool.uv.workspace] -members = ["providers/vdb/*"] -exclude = ["providers/vdb/__pycache__"] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] [tool.uv.sources] dify-vdb-alibabacloud-mysql = { workspace = true } @@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true } dify-vdb-vastbase = { workspace = true } dify-vdb-vikingdb = { workspace = true } dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } [tool.uv] -default-groups = ["storage", "tools", "vdb-all"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false override-dependencies = [ "pyarrow>=18.0.0", @@ -171,7 +173,7 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.60.0", + "pyrefly>=0.61.1", "xinference-client>=2.4.0", ] @@ -272,6 +274,25 @@ vdb-weaviate = ["dify-vdb-weaviate"] # Optional client used by some tests / integrations (not a vector backend plugin) vdb-xinference = ["xinference-client>=2.4.0"] +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] +trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] + [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 3e5ece1fcf..fbbca24558 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,12 +34,12 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index c4582e891d..ac0e2a3a53 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -5,7 +5,8 @@ ".venv", "migrations/", "core/rag", - "providers/", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c..2cc0192a4a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 74b800606d..70184b42de 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -23,6 +23,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): diff --git a/api/services/app_service.py b/api/services/app_service.py index afd98e2975..038c59633a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e6f5f80a6d..894cb05687 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -30,7 +30,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.file import helpers as file_helpers from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/services/feature_service.py b/api/services/feature_service.py index e4eb9e7582..a6b02630bf 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -225,6 +227,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 77576fa4c0..17b28bbe48 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -9,7 +9,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index aa7456dcd3..8c9a81af87 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod @@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 0ffbef8365..9d446f6d4b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database @@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 073eed221c..2964537c35 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ def get_pipeline_templates(self, language: str) -> dict[str, Any]: - result = self.fetch_pipeline_templates_from_db(language) - return result + return self.fetch_pipeline_templates_from_db(language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index d5ef745bec..9565ac46cc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result: dict[str, Any] | None try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) def get_type(self) -> str: return PipelineTemplateType.REMOTE diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 605689226a..eac1093299 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -476,7 +476,7 @@ class RagPipelineService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 4d58a9cf12..56880f84af 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -170,7 +170,7 @@ class VariableTruncator(BaseTruncator): return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index fae5dea3cb..8c0878aa79 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -1075,7 +1075,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0caa066485..95c931a46e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,6 +6,39 @@ from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast from sqlalchemy import and_, exists, or_, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.entities import PluginCredentialType +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager +from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type +from core.workflow.human_input_adapter import ( + DeliveryChannelConfig, + adapt_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) +from core.workflow.node_factory import ( + LATEST_VERSION, + DifyGraphInitContext, + get_node_type_classes_mapping, + is_start_node_type, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from enums.cloud_plan import CloudPlan +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from extensions.ext_storage import storage +from factories.file_factory import build_from_mapping, build_from_mappings from graphon.entities import WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired @@ -31,53 +64,13 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable -from sqlalchemy import and_, exists, or_, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.entities import PluginCredentialType -from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager -from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type -from core.workflow.human_input_compat import ( - DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, - parse_human_input_delivery_methods, -) -from core.workflow.node_factory import ( - LATEST_VERSION, - DifyGraphInitContext, - get_node_type_classes_mapping, - is_start_node_type, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace -from enums.cloud_plan import CloudPlan -from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated -from extensions.ext_database import db -from extensions.ext_storage import storage -from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from libs.helper import escape_like_pattern from models import Account from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider -from models.workflow import ( - Workflow, - WorkflowKind, - WorkflowNodeExecutionModel, - WorkflowNodeExecutionTriggeredFrom, - WorkflowType, -) +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService from services.errors.app import ( @@ -119,7 +112,8 @@ class WorkflowService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize WorkflowService with repository dependencies.""" if session_maker is None: - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + session_maker = sessionmaker( + bind=db.engine, expire_on_commit=False) self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -223,7 +217,8 @@ class WorkflowService: if not app_ids: return set() - stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + stmt = select(App.id).where(App.id.in_( + app_ids), App.tenant_id == tenant_id) return {str(app_id) for app_id in db.session.scalars(stmt).all()} def get_all_published_workflow( @@ -283,7 +278,7 @@ class WorkflowService: """ stmt = select(Workflow).where( Workflow.tenant_id == tenant_id, - Workflow.kind == WorkflowKind.EVALUATION.value, + Workflow.type == WorkflowType.EVALUATION, Workflow.version != Workflow.VERSION_DRAFT, ) @@ -307,7 +302,8 @@ class WorkflowService: ) ) - stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit) + stmt = stmt.order_by(Workflow.created_at.desc()).limit( + limit + 1).offset((page - 1) * limit) workflows = session.scalars(stmt).all() @@ -339,7 +335,8 @@ class WorkflowService: raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure(app_model=app_model, features=features) + self.validate_features_structure( + app_model=app_model, features=features) # validate graph structure self.validate_graph_structure(graph=graph) @@ -350,7 +347,6 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - kind=WorkflowKind.STANDARD.value, version=Workflow.VERSION_DRAFT, graph=json.dumps(graph), features=json.dumps(features), @@ -367,14 +363,13 @@ class WorkflowService: workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables - if workflow.resolved_kind == WorkflowKind.STANDARD: - workflow.kind = WorkflowKind.STANDARD # commit db session changes db.session.commit() # trigger app workflow events - app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow) + app_draft_workflow_was_synced.send( + app_model, synced_draft_workflow=workflow) # return draft workflow return workflow @@ -442,7 +437,8 @@ class WorkflowService: raise ValueError("No draft workflow found.") # validate features structure - self.validate_features_structure(app_model=app_model, features=features) + self.validate_features_structure( + app_model=app_model, features=features) workflow.features = json.dumps(features) workflow.updated_by = account.id @@ -463,11 +459,13 @@ class WorkflowService: Secret environment variables are copied server-side from the selected published workflow so the normal draft sync flow stays stateless. """ - source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + source_workflow = self.get_published_workflow_by_id( + app_model=app_model, workflow_id=workflow_id) if not source_workflow: raise WorkflowNotFoundError("Workflow not found.") - self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_features_structure( + app_model=app_model, features=source_workflow.normalized_features_dict) self.validate_graph_structure(graph=source_workflow.graph_dict) draft_workflow = self.get_draft_workflow(app_model=app_model) @@ -484,7 +482,8 @@ class WorkflowService: db.session.add(draft_workflow) db.session.commit() - app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + app_draft_workflow_was_synced.send( + app_model, synced_draft_workflow=draft_workflow) return draft_workflow @@ -528,7 +527,8 @@ class WorkflowService: and is_trigger_node_type(node_type_str) ) if trigger_node_count > 2: - raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2) + raise TriggerNodeLimitExceededError( + count=trigger_node_count, limit=2) # create new workflow workflow = Workflow.new( @@ -544,14 +544,14 @@ class WorkflowService: marked_comment=marked_comment, rag_pipeline_variables=draft_workflow.rag_pipeline_variables, features=draft_workflow.features, - kind=draft_workflow.kind_or_standard, ) # commit db session changes session.add(workflow) # trigger app workflow events - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) # return new workflow return workflow @@ -568,7 +568,7 @@ class WorkflowService: """Publish draft workflow as an evaluation workflow version. Compared to standard publish: - - set business kind to ``evaluation``; + - force published workflow type to ``evaluation``; - reject graphs containing trigger or human-input nodes. """ draft_workflow_stmt = select(Workflow).where( @@ -593,7 +593,7 @@ class WorkflowService: workflow = Workflow.new( tenant_id=app_model.tenant_id, app_id=app_model.id, - type=draft_workflow.type.value, + type=WorkflowType.EVALUATION.value, version=Workflow.version_from_datetime(naive_utc_now()), graph=draft_workflow.graph, created_by=account.id, @@ -603,13 +603,13 @@ class WorkflowService: marked_comment=marked_comment, rag_pipeline_variables=draft_workflow.rag_pipeline_variables, features=draft_workflow.features, - kind=WorkflowKind.EVALUATION.value, ) session.add(workflow) # trigger app workflow events - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) return workflow @@ -618,16 +618,17 @@ class WorkflowService: *, session: Session, app_model: App, - target_type: WorkflowKind, + target_type: WorkflowType, account: Account, ) -> Workflow: """ - Convert a published workflow business kind in-place. + Convert a published workflow type in-place. This endpoint only supports conversion between standard workflow and evaluation workflow. """ - if target_type not in {WorkflowKind.STANDARD, WorkflowKind.EVALUATION}: - raise ValueError("target_type must be either 'standard' or 'evaluation'") + if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}: + raise ValueError( + "target_type must be either 'workflow' or 'evaluation'") if not app_model.workflow_id: raise WorkflowNotFoundError("Published workflow not found") @@ -642,19 +643,21 @@ class WorkflowService: raise WorkflowNotFoundError("Published workflow not found") if workflow.version == Workflow.VERSION_DRAFT: - raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.") + raise IsDraftWorkflowError( + "Current effective workflow cannot be a draft version.") - if workflow.resolved_kind == target_type: + if workflow.type == target_type: return workflow - if target_type == WorkflowKind.EVALUATION: + if target_type == WorkflowType.EVALUATION: self._validate_evaluation_workflow_nodes(workflow) - workflow.kind = target_type + workflow.type = target_type workflow.updated_by = account.id workflow.updated_at = naive_utc_now() - app_published_workflow_was_updated.send(app_model, published_workflow=workflow) + app_published_workflow_was_updated.send( + app_model, published_workflow=workflow) return workflow @@ -672,7 +675,8 @@ class WorkflowService: if not disallowed_nodes: return - formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes) + formatted_nodes = ", ".join( + f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes) raise ValueError( "Evaluation workflow cannot contain trigger or human-input nodes. " f"Found disallowed nodes: {formatted_nodes}" @@ -710,12 +714,14 @@ class WorkflowService: ) else: # Check default workspace credential for this provider - self._check_default_tool_credential(workflow.tenant_id, provider) + self._check_default_tool_credential( + workflow.tenant_id, provider) elif node_type == "agent": agent_params = node_data.get("agent_parameters", {}) - model_config = agent_params.get("model", {}).get("value", {}) + model_config = agent_params.get( + "model", {}).get("value", {}) if model_config.get("provider") and model_config.get("model"): self._validate_llm_model_config( workflow.tenant_id, model_config["provider"], model_config["model"] @@ -723,7 +729,8 @@ class WorkflowService: # Validate load balancing credentials for agent model if load balancing is enabled agent_model_node_data = {"model": model_config} - self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) + self._validate_load_balancing_credentials( + workflow, agent_model_node_data, node_id) # Validate agent tools tools = agent_params.get("tools", {}).get("value", []) @@ -735,9 +742,11 @@ class WorkflowService: if credential_id: from core.helper.credential_utils import check_credential_policy_compliance - check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) + check_credential_policy_compliance( + credential_id, provider, PluginCredentialType.TOOL) else: - self._check_default_tool_credential(workflow.tenant_id, provider) + self._check_default_tool_credential( + workflow.tenant_id, provider) elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: model_config = node_data.get("model", {}) @@ -746,11 +755,14 @@ class WorkflowService: if provider and model_name: # Validate that the provider+model combination can fetch valid credentials - self._validate_llm_model_config(workflow.tenant_id, provider, model_name) + self._validate_llm_model_config( + workflow.tenant_id, provider, model_name) # Validate load balancing credentials if load balancing is enabled - self._validate_load_balancing_credentials(workflow, node_data, node_id) + self._validate_load_balancing_credentials( + workflow, node_data, node_id) else: - raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") + raise ValueError( + f"Node {node_id} ({node_type}): Missing provider or model configuration") except Exception as e: if isinstance(e, ValueError): @@ -792,8 +804,10 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_configurations = assembly.provider_manager.get_configurations(tenant_id) - models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) + provider_configurations = assembly.provider_manager.get_configurations( + tenant_id) + models = provider_configurations.get_models( + provider=provider, model_type=ModelType.LLM) target_model = None for model in models: @@ -804,7 +818,8 @@ class WorkflowService: if target_model: target_model.raise_for_status() else: - raise ValueError(f"Model {model_name} not found for provider {provider}") + raise ValueError( + f"Model {model_name} not found for provider {provider}") except Exception as e: raise ValueError( @@ -852,7 +867,8 @@ class WorkflowService: ) except Exception as e: - raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") + raise ValueError( + f"Failed to validate default credential for tool provider {provider}: {str(e)}") def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ @@ -874,7 +890,8 @@ class WorkflowService: # Check if this model has load balancing enabled if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): # Get all load balancing configurations for this model - load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) + load_balancing_configs = self._get_load_balancing_configs( + workflow.tenant_id, provider, model_name) # Validate each load balancing configuration try: for config in load_balancing_configs: @@ -885,7 +902,8 @@ class WorkflowService: config["credential_id"], provider, PluginCredentialType.MODEL ) except Exception as e: - raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") + raise ValueError( + f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: """ @@ -900,8 +918,10 @@ class WorkflowService: from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager( + tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations( + tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -942,7 +962,8 @@ class WorkflowService: _, custom_configs = model_load_balancing_service.get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" ) - all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs) + all_configs = cast(list[dict[str, Any]], configs) + \ + cast(list[dict[str, Any]], custom_configs) return [config for config in all_configs if config.get("credential_id")] @@ -987,7 +1008,7 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -1006,7 +1027,8 @@ class WorkflowService: ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - default_config = node_class.get_default_config(filters=resolved_filters or None) + default_config = node_class.get_default_config( + filters=resolved_filters or None) if not default_config: return {} @@ -1029,7 +1051,8 @@ class WorkflowService: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id) + draft_var_srv.prefill_conversation_variable_default_values( + draft_workflow, user_id=account.id) node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) @@ -1043,7 +1066,8 @@ class WorkflowService: workflow=draft_workflow, ) if node_type == BuiltinNodeTypes.START: - start_data = StartNodeData.model_validate(node_data, from_attributes=True) + start_data = StartNodeData.model_validate( + node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs ) @@ -1078,7 +1102,8 @@ class WorkflowService: user_id=account.id, ) - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id( + node_config) if enclosing_node_type_and_id: _, enclosing_node_id = enclosing_node_type_and_id else: @@ -1113,12 +1138,15 @@ class WorkflowService: ) repository.save(node_execution) - workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id( + node_execution.id) if workflow_node_execution is None: - raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") + raise ValueError( + f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with sessionmaker(db.engine).begin() as session: - outputs = workflow_node_execution.load_full_outputs(session, storage) + outputs = workflow_node_execution.load_full_outputs( + session, storage) with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( @@ -1130,7 +1158,8 @@ class WorkflowService: node_execution_id=node_execution.id, user=account, ) - draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) + draft_var_saver.save( + process_data=node_execution.process_data, outputs=outputs) enqueue_draft_node_execution_trace( execution=workflow_node_execution, @@ -1257,7 +1286,8 @@ class WorkflowService: rendered_content, outputs, node_data.outputs_field_names() ) - enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id( + node_config) enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( @@ -1292,7 +1322,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1344,7 +1374,8 @@ class WorkflowService: try: test_service.send_test(context=context, method=delivery_method) except DeliveryTestUnsupportedError as exc: - raise ValueError("Delivery method does not support test send.") from exc + raise ValueError( + "Delivery method does not support test send.") from exc except DeliveryTestError as exc: raise ValueError(str(exc)) from exc @@ -1369,7 +1400,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) + repo = HumanInputFormRepositoryImpl( + tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( workflow_execution_id=None, node_id=node_id, @@ -1389,7 +1421,8 @@ class WorkflowService: with Session(bind=db.engine) as session: recipients = session.scalars( - select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id) + select(HumanInputFormRecipient).where( + HumanInputFormRecipient.form_id == form_id) ).all() recipients_data: list[DeliveryTestEmailRecipient] = [] for recipient in recipients: @@ -1400,11 +1433,13 @@ class WorkflowService: try: payload = json.loads(recipient.recipient_payload) except (json.JSONDecodeError, ValueError): - logger.exception("Failed to parse human input recipient payload for delivery test.") + logger.exception( + "Failed to parse human input recipient payload for delivery test.") continue email = payload.get("email") if email: - recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token)) + recipients_data.append(DeliveryTestEmailRecipient( + email=email, form_token=recipient.access_token)) return recipients_data def _build_human_input_node( @@ -1433,9 +1468,11 @@ class WorkflowService: variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data( + adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), @@ -1453,7 +1490,8 @@ class WorkflowService: ) -> VariablePool: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) + draft_var_srv.prefill_conversation_variable_default_values( + workflow, user_id=user_id) variable_pool = VariablePool() add_variables_to_pool( @@ -1531,7 +1569,8 @@ class WorkflowService: Returns: WorkflowNodeExecution: The execution result """ - node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn) + node, node_run_result, run_succeeded, error = self._execute_node_safely( + invoke_node_fn) # Create base node execution node_execution = WorkflowNodeExecution( @@ -1547,7 +1586,8 @@ class WorkflowService: ) # Populate execution result data - self._populate_execution_result(node_execution, node_run_result, run_succeeded, error) + self._populate_execution_result( + node_execution, node_run_result, run_succeeded, error) return node_execution @@ -1576,7 +1616,8 @@ class WorkflowService: # Apply error strategy if node failed if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy: - node_run_result = self._apply_error_strategy(node, node_run_result) + node_run_result = self._apply_error_strategy( + node, node_run_result) run_succeeded = node_run_result.status in ( WorkflowNodeExecutionStatus.SUCCEEDED, @@ -1607,7 +1648,8 @@ class WorkflowService: status=WorkflowNodeExecutionStatus.EXCEPTION, error=node_run_result.error, inputs=node_run_result.inputs, - metadata={WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy}, + metadata={ + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy}, outputs=error_outputs, ) @@ -1621,10 +1663,12 @@ class WorkflowService: """Populate node execution with result data.""" if run_succeeded and node_run_result: node_execution.inputs = ( - WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + WorkflowEntry.handle_special_values( + node_run_result.inputs) if node_run_result.inputs else None ) node_execution.process_data = ( - WorkflowEntry.handle_special_values(node_run_result.process_data) + WorkflowEntry.handle_special_values( + node_run_result.process_data) if node_run_result.process_data else None ) @@ -1653,7 +1697,8 @@ class WorkflowService: workflow_converter = WorkflowConverter() if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}: - raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") + raise ValueError( + f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app: App = workflow_converter.convert_to_workflow( @@ -1690,7 +1735,8 @@ class WorkflowService: # start node and trigger node cannot coexist if BuiltinNodeTypes.START in node_types: if any(is_trigger_node_type(nt) for nt in node_types): - raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") + raise ValueError( + "Start node and trigger nodes cannot coexist in the same workflow") for node in node_configs: node_data = node.get("data", {}) @@ -1725,7 +1771,8 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate( + adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1742,7 +1789,8 @@ class WorkflowService: :param data: Dictionary containing fields to update :return: Updated workflow or None if not found """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + stmt = select(Workflow).where(Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) if not workflow: @@ -1771,7 +1819,8 @@ class WorkflowService: :raises: WorkflowInUseError if workflow is in use :raises: DraftWorkflowDeletionError if workflow is a draft version """ - stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id) + stmt = select(Workflow).where(Workflow.id == workflow_id, + Workflow.tenant_id == tenant_id) workflow = session.scalar(stmt) if not workflow: @@ -1779,14 +1828,16 @@ class WorkflowService: # Check if workflow is a draft version if workflow.version == Workflow.VERSION_DRAFT: - raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") + raise DraftWorkflowDeletionError( + "Cannot delete draft workflow versions") # Check if this workflow is currently referenced by an app app_stmt = select(App).where(App.workflow_id == workflow_id) app = session.scalar(app_stmt) if app: # Cannot delete a workflow that's currently in use by an app - raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'") + raise WorkflowInUseError( + f"Cannot delete workflow that is currently in use by app '{app.id}'") # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version @@ -1800,7 +1851,8 @@ class WorkflowService: if tool_provider: # Cannot delete a workflow that's published as a tool - raise WorkflowInUseError("Cannot delete workflow that is published as a tool") + raise WorkflowInUseError( + "Cannot delete workflow that is published as a tool") session.delete(workflow) return True @@ -1829,7 +1881,7 @@ def _setup_variable_pool( } # Only add chatflow-specific variables for chat-like workflow types. - if workflow.type != WorkflowType.WORKFLOW: + if workflow.type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}: system_variable_values.update( { "query": query, @@ -1849,11 +1901,13 @@ def _setup_variable_pool( build_bootstrap_variables( system_variables=system_variable, environment_variables=workflow.environment_variables, - conversation_variables=cast(list[Variable], conversation_variables), + conversation_variables=cast( + list[Variable], conversation_variables), ), ) if is_start_node_type(node_type): - add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + add_node_inputs_to_pool( + variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1869,7 +1923,8 @@ def _rebuild_file_for_user_inputs_in_start_node( if variable.variable not in user_inputs: continue value = user_inputs[variable.variable] - file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type) + file = _rebuild_single_file( + tenant_id=tenant_id, value=value, variable_entity_type=variable.type) inputs_copy[variable.variable] = file return inputs_copy @@ -1877,15 +1932,18 @@ def _rebuild_file_for_user_inputs_in_start_node( def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]: if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): - raise ValueError(f"expected dict for file object, got {type(value)}") + raise ValueError( + f"expected dict for file object, got {type(value)}") return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): - raise ValueError(f"expected list for file list object, got {type(value)}") + raise ValueError( + f"expected list for file list object, got {type(value)}") if len(value) == 0: return [] if not isinstance(value[0], dict): - raise ValueError(f"expected dict for first element in the file list, got {type(value)}") + raise ValueError( + f"expected dict for first element in the file list, got {type(value)}") return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index a316eec7b9..36c710d180 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail from models.human_input import ( diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index b9f382eccf..b0cbc54db3 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,7 +12,6 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -29,6 +28,7 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3fdea10976..2392084c36 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: @@ -70,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22..bb0c585dc8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,6 +6,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e738..b9f7b9575b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool - from core.workflow.system_variables import build_system_variables - # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), @@ -724,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index f0f3fcead1..bf640c7b96 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory @@ -76,8 +77,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fe512c2585..f2eabb86c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,17 +3,17 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -70,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 2d728569be..f87fbaf7ae 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -3,6 +3,7 @@ import uuid from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError @@ -87,8 +88,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f9ec51ee10..a8e9422c1e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import StreamCompletedEvent from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool.tool_node import ToolNode from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -60,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 3b1570a9a8..e0e1a07089 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -9,7 +9,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 3ecf621095..241b638791 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -101,8 +101,8 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, + node_id="start", + config=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -116,8 +116,8 @@ def _build_graph( ], ) human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, + node_id="human", + config=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -130,8 +130,8 @@ def _build_graph( desc=None, ) end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, + node_id="end", + config=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index cc72dc1cf3..066b149b37 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase): file_related_id = related_id return File( - id=str(uuid4()), # Generate new UUID for File.id + file_id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, remote_url=remote_url, diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index aaf9a85d60..54b7afc018 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index c46b8fba0b..598daa86e6 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -6,7 +6,7 @@ import pytest from graphon.enums import BuiltinNodeTypes from graphon.nodes.human_input.entities import HumanInputNodeData -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 21a54e909e..ed75363f3b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -8,7 +8,7 @@ import pytest from sqlalchemy.engine import Engine from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 1b4dcf28ea..5dcebef141 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -13,7 +13,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 6ff3b19362..e91c0a0597 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 740da1f1df..7b329e0c0d 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 97fdf1a011..4fb8ecf784 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -20,7 +20,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.variables.types import SegmentType from werkzeug.exceptions import BadRequest, NotFound import services @@ -38,6 +37,9 @@ from controllers.service_api.app.conversation import ( ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment +from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -284,6 +286,32 @@ class TestConversationVariableResponseModels: assert response.created_at == int(created_at.timestamp()) assert response.updated_at == int(created_at.timestamp()) + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + def test_variable_pagination_response(self): response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( { diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12..82fcccdf40 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -12,8 +12,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a04a7b7576..6104b8d6ca 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index cddd03f4b0..701863b927 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index c4bfb23272..30a068f4c5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d338cadb77..af54aed363 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index a0b2820157..aeca2e3afd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843..a28143026f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e224..8cb0938575 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e6597..d9fed9ae2a 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb5006..c4fd970562 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42..69650c85cc 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa130518..88bf555594 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 90730dff5a..014853e480 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 2b280dd674..6fee5790cc 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 4a54649b28..28966242d8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,19 +1,18 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695..5308c8e7b3 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a..0220fb6d4a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,12 +1,14 @@ """Primarily used for testing merged cell scenarios""" +import gc import io import os import tempfile +import warnings from collections import UserDict from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from docx import Document @@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +def test_close_handles_async_close_mock(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + extractor.temp_file.close = AsyncMock() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + extractor.close() + gc.collect() + + extractor.temp_file.close.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 0fc82dda53..5bbcd2202d 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -19,7 +19,7 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df1..4248782d93 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -21,7 +21,7 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index ac65d0c02b..00a3a89c0b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -7,9 +7,9 @@ from models.workflow import Workflow def test_file_to_dict(): file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="storage_key", diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 7406b88270..9e07ea1b6d 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,23 +1,30 @@ import dataclasses +from typing import Annotated import orjson import pytest +from pydantic import BaseModel, Discriminator, Tag + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, - SegmentUnion, StringSegment, get_segment_discriminator, ) @@ -42,11 +49,26 @@ from graphon.variables.variables import ( StringVariable, Variable, ) -from pydantic import BaseModel +from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool +type SegmentUnion = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] + ), + Discriminator(get_segment_discriminator), +] def _build_variable_pool( @@ -123,7 +145,7 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, @@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad: assert restored == model def test_all_segments_serialization(self): - """Test serialization/deserialization of all segment types""" + """Test file-aware segment serialization through Dify's model boundary.""" # Create one instance of each segment type test_file = create_test_file() @@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Segments(segments=all_segments) json_str = model.model_dump_json() - loaded = _Segments.model_validate_json(json_str) + loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all segments are preserved assert len(loaded.segments) == len(all_segments) @@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad: assert loaded_segment.value == original.value def test_all_variables_serialization(self): - """Test serialization/deserialization of all variable types""" + """Test file-aware variable serialization through Dify's model boundary.""" # Create one instance of each variable type test_file = create_test_file() @@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Variables(variables=all_variables) json_str = model.model_dump_json() - loaded = _Variables.model_validate_json(json_str) + loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all variables are preserved assert len(loaded.variables) == len(all_variables) diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 09254e17a3..1752e27271 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -34,7 +34,7 @@ def create_test_file( """Factory function to create File objects for testing.""" return File( tenant_id="test-tenant", - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 88989db856..9f3e3b00b9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -1,18 +1,18 @@ -""" -Mock node factory for testing workflows with third-party service dependencies. +"""Mock node factory for third-party-service workflow tests. -This module provides a MockNodeFactory that automatically detects and mocks nodes -requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). +The factory follows the same config adaptation path as production +`DifyNodeFactory.create_node()`, but swaps selected node classes for mock +implementations before instantiation. """ from typing import TYPE_CHECKING, Any +from core.workflow.human_input_adapter import adapt_node_config_for_graph +from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import DifyNodeFactory - from .test_mock_nodes import ( MockAgentNode, MockCodeNode, @@ -83,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory): :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) + node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = typed_node_config["id"] - # Create mock node instance mock_class = self._mock_node_types[node_type] + resolved_node_data = self._validate_resolved_node_data(mock_class, node_data) if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -105,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory): ) elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -121,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory): BuiltinNodeTypes.PARAMETER_EXTRACTOR, }: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -131,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory): ) else: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -141,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8b7fbd1b30..27938e084b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -56,13 +56,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -97,7 +98,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 8311a1e847..d3272936cd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -140,8 +140,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -155,8 +155,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -165,8 +165,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,8 +183,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 7195471eb6..76b4cd1ef4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.nodes.answer.entities import AnswerNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -67,20 +67,15 @@ def test_execute_answer(): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - } - node = AnswerNode( - id=str(uuid.uuid4()), + node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, + config=AnswerNodeData( + title="123", + type="answer", + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + ), ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index fb03ae9998..d7ef781732 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: @@ -78,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=gp, graph_runtime_state=gs, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 4705b3f76e..2e89a2da3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config(): assert default_config["retry_config"]["max_retries"] == 7 -def test_get_default_config_with_malformed_http_request_config_raises_value_error(): - with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): +def test_get_default_config_with_malformed_http_request_config_raises_type_error(): + with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"): HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}) @@ -114,8 +114,8 @@ def _build_http_node( start_at=time.perf_counter(), ) return HttpRequestNode( - id="http-node", - config=node_config, + node_id="http-node", + config=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d16e1233ac..07430498e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,7 +1,6 @@ +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients - def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index c0e21d0bf7..0659984c76 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -19,7 +19,7 @@ from core.repositories.human_input_repository import ( HumanInputFormRecipientEntity, HumanInputFormRepository, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, @@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +def _build_human_input_node( + *, + node_id: str, + node_data: HumanInputNodeData | Mapping[str, Any], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + runtime: DifyHumanInputNodeRuntime, +) -> HumanInputNode: + typed_node_data = ( + node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data) + ) + return HumanInputNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + runtime=runtime, + ) + + class TestDeliveryMethod: """Test DeliveryMethod entity.""" @@ -239,7 +259,7 @@ class TestUserAction: data[field_name] = value with pytest.raises(ValidationError) as exc_info: - UserAction(**data) + UserAction.model_validate(data) errors = exc_info.value.errors() assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) @@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent: form_repository = InMemoryHumanInputFormRepository() runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 52802c7ce1..bbb4fa5f69 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -8,6 +8,7 @@ from graphon.graph_events import ( NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) +from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool @@ -26,6 +27,28 @@ class _FakeFormRepository: return self._form +def _create_human_input_node( + *, + config: dict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + repo: _FakeFormRepository, +) -> HumanInputNode: + node_data = ( + config["data"] + if isinstance(config["data"], HumanInputNodeData) + else HumanInputNodeData.model_validate(config["data"]) + ) + return HumanInputNode( + node_id=config["id"], + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + ) + + def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( @@ -81,13 +104,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) @@ -146,13 +167,11 @@ def _build_timeout_node() -> HumanInputNode: ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 82cc734274..8ffce39cd6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -5,6 +5,7 @@ import pytest from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams +from graphon.nodes.iteration.entities import IterationNodeData from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode from graphon.runtime import ( @@ -44,17 +45,14 @@ def _build_iteration_node( ) -> IterationNode: init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( - id="iteration-node", - config={ - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration-node", "output"], - "start_node_id": start_node_id, - }, - }, + node_id="iteration-node", + config=IterationNodeData( + type="iteration", + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id=start_node_id, + ), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index a6fca1bfb4..f254fc3d09 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -93,6 +93,25 @@ def sample_chunks(): } +def _build_node( + *, + node_id: str, + node_data: KnowledgeIndexNodeData | dict[str, object], + graph_init_params, + graph_runtime_state, +) -> KnowledgeIndexNode: + return KnowledgeIndexNode( + node_id=node_id, + config=( + node_data + if isinstance(node_data, KnowledgeIndexNodeData) + else KnowledgeIndexNodeData.model_validate(node_data) + ), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + class TestKnowledgeIndexNode: """ Test suite for KnowledgeIndexNode. @@ -115,9 +134,9 @@ class TestKnowledgeIndexNode: } # Act - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -143,9 +162,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -176,9 +195,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -212,9 +231,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -269,9 +288,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -332,9 +351,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -383,9 +402,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -440,9 +459,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -498,9 +517,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -536,9 +555,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -583,9 +602,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index ab64be59ad..3ab86001e8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -18,7 +18,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import ( + KnowledgeRetrievalNode, + _normalize_metadata_filter_scalar, + _normalize_metadata_filter_sequence_item, +) from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -85,6 +89,12 @@ def sample_node_data(): ) +def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None: + assert _normalize_metadata_filter_scalar(3) == 3 + assert _normalize_metadata_filter_scalar(True) == "True" + assert _normalize_metadata_filter_sequence_item(4) == "4" + + class TestKnowledgeRetrievalNode: """ Test suite for KnowledgeRetrievalNode. @@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -470,8 +480,8 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -507,8 +517,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -562,8 +572,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -610,8 +620,8 @@ class TestFetchDatasetRetriever: mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -671,8 +681,8 @@ class TestFetchDatasetRetriever: node_id = str(uuid.uuid4()) config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index fdf1706765..c0332fe99a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,8 +1,10 @@ +from types import SimpleNamespace from unittest.mock import MagicMock import pytest from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.entities import ListOperatorNodeData from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment @@ -13,11 +15,28 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" + @staticmethod + def _build_node(*, config, graph_init_params, graph_runtime_state): + return ListOperatorNode( + node_id="test", + config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + @staticmethod + def _filter_by(comparison_operator: str, value: str) -> dict[str, object]: + return { + "enabled": True, + "conditions": [{"comparison_operator": comparison_operator, "value": value}], + } + @pytest.fixture def mock_graph_runtime_state(self): """Create mock GraphRuntimeState.""" mock_state = MagicMock(spec=GraphRuntimeState) mock_variable_pool = MagicMock() + mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value) mock_state.variable_pool = mock_variable_pool return mock_state @@ -45,9 +64,8 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable - return ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + return self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -64,9 +82,8 @@ class TestListOperatorNode: "limit": {"enabled": False}, } - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -109,9 +126,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=[]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -128,11 +144,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "contains", - "value": "app", - }, + "filter_by": self._filter_by("contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -140,9 +152,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -157,11 +168,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "not contains", - "value": "app", - }, + "filter_by": self._filter_by("not contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -169,9 +176,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -186,11 +192,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "5", - }, + "filter_by": self._filter_by(">", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -198,9 +200,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -226,9 +227,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -254,9 +254,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,9 +281,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -299,11 +297,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "3", - }, + "filter_by": self._filter_by(">", "3"), "order_by": { "enabled": True, "value": "desc", @@ -317,9 +311,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -341,9 +334,8 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -366,9 +358,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["first", "middle", "last"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,11 +375,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "start with", - "value": "app", - }, + "filter_by": self._filter_by("start with", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -396,9 +383,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -413,11 +399,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "end with", - "value": "le", - }, + "filter_by": self._filter_by("end with", "le"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -425,9 +407,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -442,11 +423,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "=", - "value": "5", - }, + "filter_by": self._filter_by("=", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -454,9 +431,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -471,11 +447,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "≠", - "value": "5", - }, + "filter_by": self._filter_by("≠", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -483,9 +455,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -511,9 +482,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 4186bbdc93..212ad07bd3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -71,8 +71,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -95,6 +95,8 @@ def variable_pool() -> VariablePool: def _fetch_prompt_messages_with_mocked_content(content): variable_pool = VariablePool.empty() model_instance = mock.MagicMock(spec=ModelInstance) + model_schema = mock.MagicMock() + model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text" prompt_template = [ LLMNodeChatModelMessage( text="You are a classifier.", @@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( "graphon.nodes.llm.llm_utils.fetch_model_schema", - return_value=mock.MagicMock(features=[]), + return_value=model_schema, ), mock.patch( "graphon.nodes.llm.llm_utils.handle_list_messages", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 7841bf05ad..453c9dba4d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -140,8 +140,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -205,14 +205,10 @@ def llm_node( mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers(): def test_fetch_files_with_file_segment(): file = File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment(): def test_fetch_files_with_array_file_segment(): files = [ File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", storage_key="", ), File( - id="2", - type=FileType.IMAGE, + file_id="2", + file_type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", @@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/jpg", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: image_b64_data = base64.b64encode(image_raw_data).decode() mock_saved_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", extension=".png", @@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable file_saver=file_saver, file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, reasoning_format="separated", ) ) @@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=_build_prepared_llm_mock(), reasoning_format="separated", request_start_time=1.0, @@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors(): file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=model_instance, ) ) diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index bc44ececd8..892f6cc586 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params +def _build_template_transform_node( + *, + node_data, + graph_init_params, + graph_runtime_state, + node_id: str = "test_node", + **kwargs, +) -> TemplateTransformNode: + typed_node_data = ( + node_data + if isinstance(node_data, TemplateTransformNodeData) + else TemplateTransformNodeData.model_validate(node_data) + ) + return TemplateTransformNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + **kwargs, + ) + + class TestTemplateTransformNode: """Comprehensive test suite for TemplateTransformNode.""" @@ -59,9 +81,8 @@ class TestTemplateTransformNode: def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -75,9 +96,8 @@ class TestTemplateTransformNode: def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -88,9 +108,8 @@ class TestTemplateTransformNode: def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -108,9 +127,8 @@ class TestTemplateTransformNode: } mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -143,9 +161,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() with pytest.raises(ValueError, match="max_output_length must be a positive integer"): - TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -170,9 +187,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -198,9 +214,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Value: " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -218,9 +233,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -238,9 +252,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -260,9 +273,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "1234567890" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -302,9 +314,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -375,8 +386,8 @@ class TestTemplateTransformNode: ) assert mapping == { - "node_123.var1": ["sys", "input1"], - "node_123.empty_selector": [], + "node_123.var1": ("sys", "input1"), + "node_123.empty_selector": (), } def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): @@ -409,9 +420,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a static message." - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -448,9 +458,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Total: $31.5" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -477,9 +486,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -507,9 +515,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index 636237e56e..a846efbb43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -4,6 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, @@ -37,15 +38,13 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( - id="test_node", - config={ - "id": "test_node", - "data": { - "title": "Template Transform", - "variables": [], - "template": "hello", - }, - }, + node_id="test_node", + config=TemplateTransformNodeData( + title="Template Transform", + type="template-transform", + variables=[], + template="hello", + ), graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=MagicMock(), @@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri assert mapping == { "node_123.validated": ["sys", "input1"], - "node_123.raw": ["sys", "input2"], + "node_123.raw": ("sys", "input2"), } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index e11ebf6eb8..824f52a89b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -3,7 +3,6 @@ from collections.abc import Mapping import pytest from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool @@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - }, - } - ) +def _build_node_config() -> dict[str, object]: + return { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + ), + } + + +def _build_node_data() -> _SampleNodeData: + return _build_node_config()["data"] # type: ignore[return-value] def test_node_hydrates_data_during_initialization(): @@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization(): init_params, runtime_state = _build_context(graph_config) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum(): ) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): - node_data = _SampleNodeData.model_validate( - { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - } - ) + node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar") assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" @@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): def test_node_hydration_preserves_compatibility_extra_fields(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) - node_config = NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - "compat_flag": True, - }, - } - ) + node_config = { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + compat_flag=True, + ), + } node = _SampleNode( - id="node-1", - config=node_config, + node_id="node-1", + config=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 555ff0c945..896ec61092 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -9,14 +9,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, + _extract_text_from_file, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from graphon.variables import ArrayFileSegment +from graphon.variables import ArrayFileSegment, FileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable @@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - node_config = {"id": "test_node_id", "data": node_data.model_dump()} http_client = Mock() node = DocumentExtractorNode( - id="test_node_id", - config=node_config, + node_id="test_node_id", + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"] - mock_excel_instance.parse.side_effect = [df, Exception("Parse error")] + mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_mixed_content" @@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"] - mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")] + mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_all_bad_sheets" @@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert mock_excel_instance.parse.call_count == 2 +@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook")) +def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file): + with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"): + _extract_text_from_excel(b"broken") + + @patch("pandas.ExcelFile") def test_extract_text_from_excel_numeric_type_column(mock_excel_file): """Test extracting text from Excel file with numeric column names.""" @@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): assert expected_manual == result +@pytest.mark.parametrize( + ("extension", "mime_type"), + [ + (".xlsx", "text/plain"), + (None, "application/vnd.ms-excel"), + ], +) +def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type): + file = Mock(spec=File) + file.extension = extension + file.mime_type = mime_type + + with ( + patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", + ), + patch( + "graphon.nodes.document_extractor.node._extract_text_from_excel", + return_value="excel text", + ) as mock_extract, + ): + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + assert result == "excel text" + mock_extract.assert_called_once_with(b"excel") + + +def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): + file = Mock(spec=File) + file.extension = None + file.mime_type = None + + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"unknown", + ): + with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"): + _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + +def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_list = Mock(spec=ArrayFileSegment) + file_list.value = [Mock(spec=File)] + mock_graph_runtime_state.variable_pool.get.return_value = file_list + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("bad file"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "bad file" + + +def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("single file failed"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "single file failed" + + +def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + return_value="single file text", + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {"text": "single file text"} + + def _make_docx_zip(use_backslash: bool) -> bytes: """Helper to build a minimal in-memory DOCX zip. diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 1b14f0ab13..18cb136322 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -19,6 +19,20 @@ from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params +def _build_if_else_node( + *, + node_data: IfElseNodeData | dict[str, object], + init_params, + graph_runtime_state, +) -> IfElseNode: + return IfElseNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + ) + + def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} @@ -61,9 +75,8 @@ def test_execute_if_else_result_true(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "and", @@ -104,13 +117,8 @@ def test_execute_if_else_result_true(): {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -155,9 +163,8 @@ def test_execute_if_else_result_false(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "or", @@ -174,13 +181,8 @@ def test_execute_if_else_result_false(): }, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -222,11 +224,6 @@ def test_array_file_contains_file_name(): ], ) - node_config = { - "id": "if-else", - "data": node_data.model_dump(), - } - # Create properly configured mock for graph_init_params graph_init_params = Mock() graph_init_params.workflow_id = "test_workflow" @@ -242,17 +239,12 @@ def test_array_file_contains_file_name(): } } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=graph_init_params, - graph_runtime_state=Mock(), - config=node_config, - ) + node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock()) node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", @@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition): "logical_operator": "and", "conditions": [condition.model_dump()], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() @@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions(): ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={ - "id": "if-else", - "data": node_data, - }, ) # Mock db.session.close() @@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure(): } ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d28c3e01e5..e7bf78f851 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -19,6 +19,15 @@ from graphon.variables import ArrayFileSegment from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: + return ListOperatorNode( + node_id="test_node_id", + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=MagicMock(), + ) + + @pytest.fixture def list_operator_node(): config = { @@ -35,10 +44,6 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData.model_validate(config) - node_config = { - "id": "test_node_id", - "data": node_data.model_dump(), - } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() graph_init_params.workflow_id = "test_workflow" @@ -54,12 +59,7 @@ def list_operator_node(): } } - node = ListOperatorNode( - id="test_node_id", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=MagicMock(), - ) + node = _build_list_operator_node(node_data, graph_init_params) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node @@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node): files = [ File( filename="image1.jpg", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", ), File( filename="document1.pdf", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", ), File( filename="image2.png", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", ), File( filename="audio1.mp3", - type=FileType.AUDIO, + file_type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", extension=".txt", @@ -156,7 +156,7 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, extension=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 833c303052..56bbb923b3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables): inputs=user_inputs, ) - config = { - "id": "start", - "data": StartNodeData(title="Start", variables=variables).model_dump(), - } + node_data = StartNodeData(title="Start", variables=variables) graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, @@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables): ) return StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -109,7 +106,7 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"): + with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"): node._run() @@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot(): inputs={"profile": {"age": 20, "name": "Tom"}}, ) - config = { - "id": "start", - "data": StartNodeData( - title="Start", - variables=[ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ], - ).model_dump(), - } + node_data = StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c806181340..284af68319 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment @@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode: runtime = _StubToolRuntime() node = ToolNode( - id="node-instance", - config=config, + node_id="node-instance", + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode: return node -def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: +def _collect_events(generator: Generator) -> list[Any]: events: list[Any] = [] try: while True: events.append(next(generator)) - except StopIteration as stop: - return events, stop.value + except StopIteration: + return events def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: @@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li node_id=tool_node._node_id, tool_runtime=ToolRuntimeHandle(raw=object()), ) - return _collect_events(generator) + events = _collect_events(generator) + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert completed_events + return events, completed_events[-1].node_run_result.llm_usage def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", @@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284..e3b5e3b591 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": TRIGGER_PLUGIN_NODE_TYPE, - "title": "Trigger Event", - "plugin_id": "plugin-id", - "provider_id": "provider-id", - "event_name": "event-name", - "subscription_id": "subscription-id", - "plugin_unique_identifier": "plugin-unique-identifier", - "event_parameters": {}, - }, - } +def _build_node_data() -> TriggerEventNodeData: + return TriggerEventNodeData( + type=TRIGGER_PLUGIN_NODE_TYPE, + title="Trigger Event", + plugin_id="plugin-id", + provider_id="provider-id", + event_name="event-name", + subscription_id="subscription-id", + plugin_unique_identifier="plugin-unique-identifier", + event_parameters={}, ) def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 1bbc12b23f..07d03bec05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 427afa96ec..b839490d3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node @@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py new file mode 100644 index 0000000000..8b5fceeb37 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py @@ -0,0 +1,350 @@ +from types import SimpleNamespace + +import pytest +from pydantic import BaseModel + +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + adapt_human_input_node_data_for_graph, + adapt_node_config_for_graph, + adapt_node_data_for_graph, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert " Team") - html = EmailDeliveryConfig.render_markdown_body( - "**Hello** [mail](mailto:test@example.com)" - ) - - assert rendered == "Open https://example.com and use 42" - assert sanitized == "Hello alert(1) Team" - assert "Hello" in html - assert " + +` + +export const getChromePluginContent = (iframeUrl: string) => `ChatBot URL: ${iframeUrl}` + +export const compressAndEncodeBase64 = async (input: string) => { + const uint8Array = new TextEncoder().encode(input) + if (typeof CompressionStream === 'undefined') + return btoa(String.fromCharCode(...uint8Array)) + + const compressedStream = new Response( + new Blob([uint8Array]) + .stream() + .pipeThrough(new CompressionStream('gzip')), + ).arrayBuffer() + const compressedUint8Array = new Uint8Array(await compressedStream) + return btoa(String.fromCharCode(...compressedUint8Array)) +} + export const getAppCardDisplayState = ({ appInfo, cardType, diff --git a/web/app/components/app/overview/app-card.tsx b/web/app/components/app/overview/app-card.tsx index f02929c823..fc0132c45a 100644 --- a/web/app/components/app/overview/app-card.tsx +++ b/web/app/components/app/overview/app-card.tsx @@ -1,13 +1,14 @@ 'use client' +import type { WorkflowLaunchInputValue } from './app-card-utils' import type { ConfigParams } from './settings' import type { AppDetailResponse } from '@/models/app' import type { AppSSO } from '@/types/app' +import { Switch } from '@langgenius/dify-ui/switch' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import AppBasic from '@/app/components/app-sidebar/basic' import { useStore as useAppStore } from '@/app/components/app/store' -import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button' import Indicator from '@/app/components/header/indicator' @@ -27,11 +28,16 @@ import { AppCardOperations, AppCardUrlSection, createAppCardOperations, + WorkflowLaunchDialog, } from './app-card-sections' import { + buildWorkflowLaunchUrl, + createWorkflowLaunchInitialValues, getAppCardDisplayState, getAppCardOperationKeys, + getAppHiddenLaunchVariables, isAppAccessConfigured, + isWorkflowLaunchInputSupported, } from './app-card-utils' export type IAppCardProps = { @@ -62,7 +68,8 @@ function AppCard({ const router = useRouter() const pathname = usePathname() const { isCurrentWorkspaceManager, isCurrentWorkspaceEditor } = useAppContext() - const { data: currentWorkflow } = useAppWorkflow(appInfo.mode === AppModeEnum.WORKFLOW ? appInfo.id : '') + const shouldFetchWorkflow = appInfo.mode === AppModeEnum.WORKFLOW || appInfo.mode === AppModeEnum.ADVANCED_CHAT + const { data: currentWorkflow } = useAppWorkflow(shouldFetchWorkflow ? appInfo.id : '') const docLink = useDocLink() const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(state => state.setAppDetail) @@ -72,6 +79,8 @@ function AppCard({ const [genLoading, setGenLoading] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) const [showAccessControl, setShowAccessControl] = useState(false) + const [showWorkflowLaunchDialog, setShowWorkflowLaunchDialog] = useState(false) + const [workflowLaunchValues, setWorkflowLaunchValues] = useState>({}) const { t } = useTranslation() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const { data: appAccessSubjects } = useAppWhiteListSubjects( @@ -97,6 +106,25 @@ function AppCard({ () => isAppAccessConfigured(appDetail, appAccessSubjects), [appAccessSubjects, appDetail], ) + const hiddenLaunchVariables = useMemo( + () => getAppHiddenLaunchVariables({ + appInfo, + currentWorkflow, + }) || [], + [appInfo, currentWorkflow], + ) + const supportedWorkflowLaunchVariables = useMemo( + () => hiddenLaunchVariables.filter(isWorkflowLaunchInputSupported), + [hiddenLaunchVariables], + ) + const unsupportedWorkflowLaunchVariables = useMemo( + () => hiddenLaunchVariables.filter(variable => !isWorkflowLaunchInputSupported(variable)), + [hiddenLaunchVariables], + ) + const initialWorkflowLaunchValues = useMemo( + () => createWorkflowLaunchInitialValues(supportedWorkflowLaunchVariables), + [supportedWorkflowLaunchVariables], + ) const onGenCode = async () => { if (!onGenerateCode) @@ -138,6 +166,31 @@ function AppCard({ window.open(cardState.accessibleUrl, '_blank') }, [cardState.accessibleUrl]) + const handleOpenWorkflowLaunchDialog = useCallback(() => { + setWorkflowLaunchValues(initialWorkflowLaunchValues) + setShowWorkflowLaunchDialog(true) + }, [initialWorkflowLaunchValues]) + + const handleWorkflowLaunchValueChange = useCallback((variable: string, value: WorkflowLaunchInputValue) => { + setWorkflowLaunchValues(prev => ({ + ...prev, + [variable]: value, + })) + }, []) + + const handleWorkflowLaunchConfirm = useCallback(async (event: React.FormEvent) => { + event.preventDefault() + + const targetUrl = await buildWorkflowLaunchUrl({ + accessibleUrl: cardState.accessibleUrl, + variables: supportedWorkflowLaunchVariables, + values: workflowLaunchValues, + }) + + window.open(targetUrl, '_blank') + setShowWorkflowLaunchDialog(false) + }, [cardState.accessibleUrl, supportedWorkflowLaunchVariables, workflowLaunchValues]) + const handleOpenCustomize = useCallback(() => { setShowCustomizeModal(true) }, []) @@ -279,7 +332,17 @@ function AppCard({ {!cardState.isMinimalState && (
{!isApp && } - + 0 + ? { + label: t('operation.config', { ns: 'common' }), + disabled: triggerModeDisabled || !cardState.runningStatus, + onClick: handleOpenWorkflowLaunchDialog, + } + : undefined} + />
)} @@ -298,6 +361,17 @@ function AppCard({ onCloseAccessControl={() => setShowAccessControl(false)} onSaveSiteConfig={onSaveSiteConfig} onConfirmAccessControl={handleAccessControlUpdate} + hiddenInputs={hiddenLaunchVariables} + /> + ) diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx index b1d7066779..e4b8aabf69 100644 --- a/web/app/components/app/overview/customize/index.tsx +++ b/web/app/components/app/overview/customize/index.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline' +import { Button } from '@langgenius/dify-ui/button' import * as React from 'react' import { useTranslation } from 'react-i18next' import Modal from '@/app/components/base/modal' import Tag from '@/app/components/base/tag' -import { Button } from '@/app/components/base/ui/button' import { useDocLink } from '@/context/i18n' import { AppModeEnum } from '@/types/app' diff --git a/web/app/components/app/overview/embedded/__tests__/index.spec.tsx b/web/app/components/app/overview/embedded/__tests__/index.spec.tsx index 0a843c26fd..ac4574b5f4 100644 --- a/web/app/components/app/overview/embedded/__tests__/index.spec.tsx +++ b/web/app/components/app/overview/embedded/__tests__/index.spec.tsx @@ -1,10 +1,11 @@ import type { SiteInfo } from '@/models/share' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import copy from 'copy-to-clipboard' import * as React from 'react' - import { act } from 'react' -import { afterAll, afterEach, describe, expect, it, vi } from 'vitest' + +import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { InputVarType } from '@/app/components/workflow/types' import Embedded from '../index' vi.mock('../style.module.css', () => ({ @@ -46,6 +47,7 @@ vi.mock('@/context/app-context', () => ({ })) const mockWindowOpen = vi.spyOn(window, 'open').mockImplementation(() => null) const mockedCopy = vi.mocked(copy) +const originalCompressionStream = globalThis.CompressionStream const siteInfo: SiteInfo = { title: 'test site', @@ -70,6 +72,22 @@ const getCopyButton = () => { } describe('Embedded', () => { + beforeAll(() => { + class MockCompressionStream { + readable: ReadableStream + writable: WritableStream + + constructor() { + const transformStream = new TransformStream() + this.readable = transformStream.readable + this.writable = transformStream.writable + } + } + + // @ts-expect-error test polyfill + globalThis.CompressionStream = MockCompressionStream + }) + afterEach(() => { vi.clearAllMocks() mockWindowOpen.mockClear() @@ -77,6 +95,7 @@ describe('Embedded', () => { afterAll(() => { mockWindowOpen.mockRestore() + globalThis.CompressionStream = originalCompressionStream }) it('builds theme and copies iframe snippet', async () => { @@ -84,14 +103,20 @@ describe('Embedded', () => { render() }) + await waitFor(() => { + expect(screen.getByText((content, node) => node?.tagName.toLowerCase() === 'pre' && content.includes('/chatbot/token'))).toBeInTheDocument() + }) + const actionButton = getCopyButton() const innerDiv = actionButton.querySelector('div') - act(() => { + await act(async () => { fireEvent.click(innerDiv ?? actionButton) }) expect(mockThemeBuilder.buildTheme).toHaveBeenCalledWith(siteInfo.chat_color_theme, siteInfo.chat_color_theme_inverted) - expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token')) + await waitFor(() => { + expect(mockedCopy).toHaveBeenCalledWith(expect.stringContaining('/chatbot/token')) + }) }) it('opens chrome plugin store link when chrome option selected', async () => { @@ -116,4 +141,53 @@ describe('Embedded', () => { 'noopener,noreferrer', ) }) + + it('keeps hidden inputs collapsed by default and updates iframe and script content when values change', async () => { + render( + , + ) + + expect(screen.queryByLabelText('Secret')).not.toBeInTheDocument() + + await act(async () => { + fireEvent.click(screen.getByText('appOverview.overview.appInfo.embedded.hiddenInputs.title').closest('button')!) + }) + + await waitFor(() => { + expect(screen.getByLabelText('Secret')).toBeInTheDocument() + }) + + await act(async () => { + fireEvent.change(screen.getByLabelText('Secret'), { + target: { value: 'top-secret' }, + }) + }) + + expect(document.querySelector('pre')?.textContent ?? '').toContain('/chatbot/token') + + await waitFor(() => { + const codeBlock = document.querySelector('pre') + expect(codeBlock?.textContent ?? '').toContain('/chatbot/token?secret=dG9wLXNlY3JldA%3D%3D') + }) + + const optionButtons = document.body.querySelectorAll('[class*="option"]') + act(() => { + fireEvent.click(optionButtons[1]) + }) + + await waitFor(() => { + const codeBlock = document.querySelector('pre') + expect(codeBlock?.textContent ?? '').toContain('secret: "top-secret"') + }) + }) }) diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index 029566587e..f19f3aba80 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -1,90 +1,48 @@ +import type { MutableRefObject } from 'react' +import type { WorkflowHiddenStartVariable, WorkflowLaunchInputValue } from '../app-card-utils' import type { SiteInfo } from '@/models/share' import { cn } from '@langgenius/dify-ui/cn' import { + RiArrowDownSLine, + RiArrowRightSLine, RiClipboardFill, RiClipboardLine, } from '@remixicon/react' import copy from 'copy-to-clipboard' -import * as React from 'react' -import { useEffect, useState } from 'react' +import { Suspense, use, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import ActionButton from '@/app/components/base/action-button' import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context' import Modal from '@/app/components/base/modal' import Tooltip from '@/app/components/base/tooltip' -import { IS_CE_EDITION } from '@/config' +import { InputVarType } from '@/app/components/workflow/types' import { useAppContext } from '@/context/app-context' import { basePath } from '@/utils/var' +import { + compressAndEncodeBase64, + createWorkflowLaunchInitialValues, + getChromePluginContent, + getEmbeddedIframeSnippet, + getEmbeddedScriptSnippet, + isWorkflowLaunchInputSupported, +} from '../app-card-utils' +import WorkflowHiddenInputFields from '../workflow-hidden-input-fields' import style from './style.module.css' type Props = { siteInfo?: SiteInfo isShow: boolean onClose: () => void - accessToken: string - appBaseUrl: string + accessToken?: string + appBaseUrl?: string + hiddenInputs?: WorkflowHiddenStartVariable[] className?: string } -const OPTION_MAP = { - iframe: { - getContent: (url: string, token: string) => - ``, - }, - scripts: { - getContent: (url: string, token: string, primaryColor: string, isTestEnv?: boolean) => - ` - -`, - }, - chromePlugin: { - getContent: (url: string, token: string) => `ChatBot URL: ${url}${basePath}/chatbot/${token}`, - }, -} +const OPTION_KEYS = ['iframe', 'scripts', 'chromePlugin'] as const const prefixEmbedded = 'overview.appInfo.embedded' -type Option = keyof typeof OPTION_MAP +type Option = typeof OPTION_KEYS[number] type OptionStatus = { iframe: boolean @@ -92,58 +50,192 @@ type OptionStatus = { chromePlugin: boolean } -const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, className }: Props) => { +const EMPTY_OPTION_STATUS: OptionStatus = { + iframe: false, + scripts: false, + chromePlugin: false, +} + +const getSerializedHiddenInputValue = ( + variable: WorkflowHiddenStartVariable, + values: Record, +) => { + const rawValue = values[variable.variable] + if (variable.type === InputVarType.checkbox) + return String(Boolean(rawValue)) + + return String(rawValue ?? '') +} + +const buildEmbeddedIframeUrl = async ({ + appBaseUrl, + accessToken, + variables, + values, +}: { + appBaseUrl: string + accessToken: string + variables: WorkflowHiddenStartVariable[] + values: Record +}) => { + const iframeUrl = new URL(`${appBaseUrl}${basePath}/chatbot/${accessToken}`, window.location.origin) + + await Promise.all(variables.map(async (variable) => { + iframeUrl.searchParams.set(variable.variable, await compressAndEncodeBase64(getSerializedHiddenInputValue(variable, values))) + })) + + return iframeUrl.toString() +} + +const AsyncEmbeddedOptionContent = ({ + option, + iframeUrlPromise, + latestResolvedIframeUrlRef, +}: { + option: Option + iframeUrlPromise: Promise + latestResolvedIframeUrlRef: MutableRefObject +}) => { + const iframeUrl = use(iframeUrlPromise) + latestResolvedIframeUrlRef.current = iframeUrl + + if (option === 'chromePlugin') + return getChromePluginContent(iframeUrl) + + return getEmbeddedIframeSnippet(iframeUrl) +} + +const EmbeddedContent = ({ + siteInfo, + appBaseUrl, + accessToken, + hiddenInputs, +}: Required> & Pick) => { const { t } = useTranslation() + const supportedHiddenInputs = useMemo( + () => (hiddenInputs ?? []).filter(isWorkflowLaunchInputSupported), + [hiddenInputs], + ) + const initialHiddenInputValues = useMemo( + () => createWorkflowLaunchInitialValues(supportedHiddenInputs), + [supportedHiddenInputs], + ) const [option, setOption] = useState